From 92a1e0fbae2bf7dd01206d15d4e6777462b55b16 Mon Sep 17 00:00:00 2001 From: yuanwu Date: Tue, 24 Sep 2024 03:06:55 +0000 Subject: [PATCH 001/340] Aligin the source code with main branch 2.0.4 Signed-off-by: yuanwu --- benchmark/src/generation.rs | 6 +- benchmark/src/main.rs | 2 +- router/client/Cargo.toml | 1 - router/client/src/client.rs | 368 +-- router/client/src/sharded_client.rs | 4 - router/src/infer.rs | 13 +- router/src/main.rs | 33 +- router/src/queue.rs | 211 +- router/src/server.rs | 4 - router/src/validation.rs | 50 +- server/Makefile | 14 +- server/dill-0.3.7-patch.sh | 91 - server/dill-0.3.8-patch.sh | 91 - server/poetry.lock | 2821 ++++++++--------- server/pyproject.toml | 46 +- server/requirements.txt | 88 - server/tests/models/test_bloom.py | 43 +- server/tests/models/test_causal_lm.py | 290 +- server/tests/models/test_grammar.py | 245 -- server/tests/models/test_seq2seq_lm.py | 7 - server/tests/models/test_starcoder.py | 372 --- server/tests/utils/test_tokens.py | 44 - server/tests/utils/test_watermark.py | 85 +- server/text_generation_server/cli.py | 121 +- .../habana_quantization_env.py | 27 - server/text_generation_server/interceptor.py | 9 +- .../layers/gptq/exllamav2.py | 2 + .../layers/gptq/exllamav2.py.rej | 10 - .../layers/gptq/quant_linear.py | 712 ++--- .../text_generation_server/models/__init__.py | 783 ++++- server/text_generation_server/models/bloom.py | 94 +- .../models/causal_lm.py | 1508 ++++----- .../models/custom_modeling/llava_next.py | 427 ++- .../models/flash_mistral.py | 2 +- server/text_generation_server/models/model.py | 28 +- .../models/santacoder.py | 47 +- .../models/starcoder.py | 47 - .../models/vlm_causal_lm.py | 1150 +------ server/text_generation_server/server.py | 125 +- server/text_generation_server/tgi_service.py | 37 - .../text_generation_server/utils/__init__.py | 6 - server/text_generation_server/utils/debug.py | 31 - server/text_generation_server/utils/dist.py | 11 - .../utils/gptq/quant_linear.py | 359 --- .../utils/logits_process.py | 95 +- server/text_generation_server/utils/tokens.py | 125 +- .../text_generation_server/utils/version.py | 12 - .../text_generation_server/utils/watermark.py | 26 +- 48 files changed, 3972 insertions(+), 6751 deletions(-) delete mode 100644 server/dill-0.3.7-patch.sh delete mode 100644 server/dill-0.3.8-patch.sh delete mode 100644 server/requirements.txt delete mode 100644 server/tests/models/test_grammar.py delete mode 100644 server/tests/models/test_starcoder.py delete mode 100644 server/text_generation_server/habana_quantization_env.py delete mode 100644 server/text_generation_server/layers/gptq/exllamav2.py.rej delete mode 100644 server/text_generation_server/models/starcoder.py delete mode 100644 server/text_generation_server/tgi_service.py delete mode 100644 server/text_generation_server/utils/debug.py delete mode 100644 server/text_generation_server/utils/gptq/quant_linear.py delete mode 100644 server/text_generation_server/utils/version.py diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b2766d0c..ea7c9778 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -163,11 +163,8 @@ async fn prefill( // Run prefill let start_time = Instant::now(); - let (_, decode_batch, _) = client.prefill(batch.clone()).await?; - let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?; - // Get latency let latency = start_time.elapsed(); @@ -183,12 +180,11 @@ async fn prefill( }; Ok((step, decode_batch)) - } /// Run a full decode async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result { - let mut decode_length = 1; // 1 decode step was already scheduled in prefill with speculative scheduling + let mut decode_length = 0; let batch_size = batch.size; let start_time = Instant::now(); diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 34b91a92..2d89e045 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -51,7 +51,7 @@ struct Args { runs: usize, /// Number of warmup cycles - #[clap(default_value = "3", short, long, env)] + #[clap(default_value = "1", short, long, env)] warmups: usize, /// The location of the grpc socket. This benchmark tool bypasses the router diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index bc4ae72e..d0131784 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -9,7 +9,6 @@ homepage.workspace = true futures = "^0.3" grpc-metadata = { path = "../grpc-metadata" } prost = "^0.12" -rand = "0.8.5" thiserror = "^1.0" tokio = { version = "^1.32", features = ["sync"] } tonic = "^0.10" diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 4a0d44ee..e8035106 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,13 +1,9 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - /// Single shard Client use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v2::*; use crate::Result; -use std::env; -use rand::{distributions::Uniform, Rng}; use grpc_metadata::InjectTelemetryContext; -use std::cmp; +use std::cmp::min; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -110,344 +106,72 @@ impl Client { max_prefill_tokens: u32, max_total_tokens: u32, max_batch_size: Option, - model_id: &str ) -> Result> { - let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true"); - if !warmup_enabled { - return Ok(None); - } - - let read_env_var = |key: &str, default: u32| -> u32 { - env::var(key).ok().map_or(default, |value| value.parse::().unwrap()) - }; - - // get all possible batch sizes - let decode_bucket_size: u32 = read_env_var("BATCH_BUCKET_SIZE", 8); - let max_decode_batch_size: u32 = match max_batch_size { - Some(max_batch_size) => max_batch_size as u32, - None => decode_bucket_size - }; - let decode_batch_sizes: Vec = (decode_bucket_size..max_decode_batch_size+1).step_by(decode_bucket_size as usize).collect(); - - let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4); - let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length; - max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size); - let prefill_batch_sizes: Vec = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect(); - - // get all possible sequence lengths for prefill - let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128); - let mut seq_lengths: Vec = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect(); - if let Some(&last) = seq_lengths.last() { - if last < max_input_length { - seq_lengths.push(max_input_length); - } - } - - // execute batch for each combination of batch size and sequence length - let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len() * seq_lengths.len()); - for batch_size in &prefill_batch_sizes { - for seq_length in &seq_lengths { - shapes.push((*batch_size, *seq_length)); - } - } - - let mut batch_counter: u64 = 0; - let mut request_counter: u64 = 0; - if model_id.contains("llava") { - let mut n_tokens = 0; - let mut requests = Vec::new(); - // Create requests - while n_tokens < max_prefill_tokens { - let truncate = cmp::min(max_input_length, max_prefill_tokens - n_tokens); - - let mut inputs = String::new(); - inputs.push_str("![]()"); - inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); - - requests.push(Request { - id: 0, - // We truncate the input on the server side to be sure that it has the correct size - inputs, - truncate, - // Set sampling parameters to also take these ops into account in the max memory - parameters: Some(NextTokenChooserParameters { - temperature: 0.9, - top_k: 10, - top_p: 0.9, - typical_p: 0.9, - do_sample: false, - seed: 0, - repetition_penalty: 1.2, - frequency_penalty: 0.1, - watermark: true, - grammar: String::new(), - grammar_type: GrammarType::None as i32, - }), - stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: max_total_tokens - truncate, - stop_sequences: vec![], - ignore_eos_token: true, - }), - prefill_logprobs: true, - top_n_tokens: 20, - }); - n_tokens += max_input_length; - - // Check max_batch_size - if Some(requests.len()) == max_batch_size { - break; - } - } - - let mut batches = Vec::new(); - batches.push(Batch { - id: 0, - size: requests.len() as u32, - requests, - max_tokens: 0, - }); - - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }) - .inject_context(); - let response = self.stub.warmup(request).await?.into_inner(); - Ok(response.max_supported_total_tokens) - } - else { - for shape in shapes.iter() { - let (batch_size, seq_length) = shape; - let mut batches: Vec = vec![ - self.create_warmup_batch( - *shape, - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - false, - None, - ) - ]; - // if possible, create second batch in order to trigger concatenate operation - if *batch_size < max_decode_batch_size { - batches.push( - self.create_warmup_batch( - (1, *seq_length), - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - false, - None, - ) - ); - } - - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }).inject_context(); - let _response = self.stub.warmup(request).await?.into_inner(); - } - - // send batches to warmup all possible decode shapes - if decode_batch_sizes.len() > 1 { - let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size { - decode_bucket_size - } else { - decode_bucket_size.div_ceil(max_prefill_batch_size) - }; - let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket; - - let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size); - let mut batches: Vec = vec![ - self.create_warmup_batch( - (requests_send, seq_bucket_size), - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - false, - Some(max_new_tokens), - ) - ]; - - let get_current_decode_batch_size = |num: u32| -> u32 { - decode_batch_sizes.iter() - .filter(|&&x| x >= num) - .min() - .copied() - .unwrap() - }; - - let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send); - while current_decode_batch_size < max_decode_batch_size { - let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send; - let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size); - batches.push( - self.create_warmup_batch( - (num_requests, seq_bucket_size), - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - false, - Some(max_new_tokens), - ) - ); - - requests_send += num_requests; - current_decode_batch_size = get_current_decode_batch_size(requests_send); - } - - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }).inject_context(); - let _response = self.stub.warmup(request).await?.into_inner(); - } - - // send batches with default params to warm up Greedy search - let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len()); - for batch_size in &prefill_batch_sizes { - greedy_shapes.push((*batch_size, seq_bucket_size.clone())); - } - for greedy_shape in greedy_shapes.iter() { - let batches: Vec = vec![ - self.create_warmup_batch( - *greedy_shape, - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - true, - None, - ) - ]; - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }).inject_context(); - let _response = self.stub.warmup(request).await?.into_inner(); - } - Ok(None) // No support for maximum total tokens - } - } - - #[instrument(skip_all)] - fn create_warmup_batch( - &mut self, - shape: (u32, u32), - batch_counter: &mut u64, - request_counter: &mut u64, - max_input_length: u32, - max_total_tokens: u32, - seq_bucket_size: u32, - default_params: bool, - max_new_tokens: Option, - ) -> Batch { - *batch_counter += 1; - let (batch_size, input_length) = shape; + let mut n_tokens = 0; let mut requests = Vec::new(); - for _ in 0..batch_size { - *request_counter += 1; - let req_params = if default_params { - Some(NextTokenChooserParameters { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - typical_p: 1.0, - do_sample: false, - seed: 0, - repetition_penalty: 1.0, - frequency_penalty: 0.0, - watermark: false, - grammar: String::new(), - grammar_type: GrammarType::None as i32, - }) - } else { - Some(NextTokenChooserParameters { + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str("![]()"); + } + + requests.push(Request { + id: 0, + // We truncate the input on the server side to be sure that it has the correct size + inputs, + truncate, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { temperature: 0.9, top_k: 10, top_p: 0.9, typical_p: 0.9, - do_sample: true, + do_sample: false, seed: 0, repetition_penalty: 1.2, frequency_penalty: 0.1, - watermark: false, + watermark: true, grammar: String::new(), grammar_type: GrammarType::None as i32, - }) - }; - requests.push(Request { - id: *request_counter, - inputs: self.get_random_input(input_length, seq_bucket_size), - truncate: max_input_length, - parameters: req_params, + }), stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: cmp::min(max_new_tokens.unwrap_or(10), max_total_tokens - max_input_length), + max_new_tokens: max_total_tokens - truncate, stop_sequences: vec![], ignore_eos_token: true, }), - prefill_logprobs: false, - top_n_tokens: 0, + prefill_logprobs: true, + top_n_tokens: 20, }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } } - Batch { - id: *batch_counter, + let batch = Batch { + id: 0, size: requests.len() as u32, requests, - max_tokens: max_total_tokens, - } - } + max_tokens: 0, + }; - #[instrument(skip_all)] - fn get_random_input( - &mut self, - input_length: u32, - seq_bucket_size: u32, - ) -> String { - let skip_tokenizer_in_tgi: bool = env::var("SKIP_TOKENIZER_IN_TGI") - .ok() - .map_or(false, |value| value.to_lowercase() == "true"); - if skip_tokenizer_in_tgi { - // generate random tokens - let mut rng = rand::thread_rng(); - let range = Uniform::new(2, 8192); - let tokens = if input_length % seq_bucket_size == 0 { - input_length - seq_bucket_size / 2 - } else { - input_length - (input_length % seq_bucket_size) / 2 - }; - (0..tokens) - .map(|_| rng.sample(&range).to_string()) - .collect::>() - .join(", ") - } else { - // repeat test string to get expected input shape - let mut bucket_id = input_length / seq_bucket_size; - if input_length % seq_bucket_size != 0 { - bucket_id += 1 - } - let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2); - "_test ".to_string().repeat(repeats as usize) - } + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) } /// Generate one token for each request in the given batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index fdd84035..e1e52d59 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,5 +1,3 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - use crate::client::{DecodeTimings, PrefillTimings}; /// Multi shard Client use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; @@ -100,7 +98,6 @@ impl ShardedClient { max_prefill_tokens: u32, max_total_tokens: u32, max_batch_size: Option, - model_id: &str, ) -> Result> { let futures: Vec<_> = self .clients @@ -111,7 +108,6 @@ impl ShardedClient { max_prefill_tokens, max_total_tokens, max_batch_size, - model_id )) }) .collect(); diff --git a/router/src/infer.rs b/router/src/infer.rs index 04328246..eef42989 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,5 +1,3 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - /// Batching and inference logic use crate::validation::{Validation, ValidationError}; use crate::{ @@ -65,22 +63,13 @@ impl Infer { max_batch_size: Option, max_concurrent_requests: usize, requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, window_size: Option, speculate: u32, generation_health: Arc, tokenizer_config: HubTokenizerConfig, ) -> Self { // Infer shared state - let queue = Queue::new( - requires_padding, - max_input_length, - max_total_tokens, - 16, - window_size, - speculate - ); + let queue = Queue::new(requires_padding, 16, window_size, speculate); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/main.rs b/router/src/main.rs index 4f9f0f73..b11c4526 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,3 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - use axum::http::HeaderValue; use clap::Parser; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; @@ -10,7 +8,6 @@ use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; -use std::env; use std::fs::File; use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -148,25 +145,6 @@ async fn main() -> Result<(), RouterError> { } } - let (max_batch_size, max_batch_total_tokens) = match (max_batch_size, max_batch_total_tokens) { - (Some(_max_batch_size), Some(_max_batch_total_tokens)) => { - if (_max_batch_total_tokens as usize / max_total_tokens) != _max_batch_size { - tracing::warn!("max_batch_size was set to {_max_batch_size} while max_batch_total_tokens to {_max_batch_total_tokens}"); - tracing::warn!("These values are not match, so max_batch_size will be preferred"); - (Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32)) - } else { - (Some(_max_batch_size), Some(_max_batch_total_tokens)) - } - }, - (Some(_max_batch_size), None) => ( - Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32) - ), - (None, Some(_max_batch_total_tokens)) => ( - Some(_max_batch_total_tokens as usize / max_total_tokens), Some(_max_batch_total_tokens) - ), - (None, None) => (None, None), - }; - // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin @@ -228,9 +206,6 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer and model info - let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI") - .ok() - .map_or(false, |value| value.to_lowercase() == "true"); let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api { Type::None => ( Some(local_path.join("tokenizer.json")), @@ -279,11 +254,8 @@ async fn main() -> Result<(), RouterError> { ) } }; - let tokenizer: Option = if skip_tokenizer_in_tgi { - None - } else { - tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()) - }; + let tokenizer: Option = + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); let config: Option = config_filename.and_then(|filename| { std::fs::read_to_string(filename) .ok() @@ -349,7 +321,6 @@ async fn main() -> Result<(), RouterError> { max_batch_prefill_tokens, max_total_tokens as u32, max_batch_size, - &model_info.model_id ) .await .map_err(RouterError::Warmup)? diff --git a/router/src/queue.rs b/router/src/queue.rs index 11690bf7..a32673dd 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -1,14 +1,9 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; -use std::cmp::{Eq, Ord, PartialEq, PartialOrd}; -use std::collections::BinaryHeap; -use std::env; -use std::time::Duration; +use std::collections::VecDeque; use text_generation_client::{Batch, Request}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; @@ -41,8 +36,6 @@ pub(crate) struct Queue { impl Queue { pub(crate) fn new( requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, block_size: u32, window_size: Option, speculate: u32, @@ -53,8 +46,6 @@ impl Queue { // Launch background queue task tokio::spawn(queue_task( requires_padding, - max_input_length, - max_total_tokens, block_size, window_size, speculate, @@ -106,21 +97,12 @@ impl Queue { // Background task responsible of the queue state async fn queue_task( requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, block_size: u32, window_size: Option, speculate: u32, mut receiver: mpsc::UnboundedReceiver, ) { - let mut state = State::new( - requires_padding, - max_input_length, - max_total_tokens, - block_size, - window_size, - speculate - ); + let mut state = State::new(requires_padding, block_size, window_size, speculate); while let Some(cmd) = receiver.recv().await { match cmd { @@ -145,104 +127,11 @@ async fn queue_task( } } -#[derive(Debug)] -struct IdentifiableEntry(u64, Entry); - -impl Eq for IdentifiableEntry {} - -impl PartialEq for IdentifiableEntry { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Ord for IdentifiableEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - let ordering = match self - .1 - .request - .input_length - .cmp(&other.1.request.input_length) - { - std::cmp::Ordering::Equal => self.0.cmp(&other.0), - any => any, - }; - - // inverse to get min heap - return ordering.reverse(); - } -} - -impl PartialOrd for IdentifiableEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[derive(Debug)] -struct QueueImpl { - regular_entries: BinaryHeap, - overdue_entries: BinaryHeap, - overdue_threshold: Duration, -} - -impl QueueImpl { - fn new(capacity: usize, overdue_threshold: Duration) -> Self { - Self { - regular_entries: BinaryHeap::with_capacity(capacity), - overdue_entries: BinaryHeap::with_capacity(capacity), - overdue_threshold, - } - } - - fn update(&mut self) { - if self.regular_entries.is_empty() { - return; - } - - let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity()); - - for entry in self.regular_entries.drain() { - if entry.1.queue_time.elapsed() > self.overdue_threshold { - self.overdue_entries.push(entry); - } else { - left.push(entry); - } - } - - self.regular_entries = left; - } - - fn push(&mut self, entry: IdentifiableEntry) { - if entry.1.queue_time.elapsed() > self.overdue_threshold { - self.overdue_entries.push(entry); - } else { - self.regular_entries.push(entry); - } - } - - fn pop(&mut self) -> Option { - if !self.overdue_entries.is_empty() { - self.overdue_entries.pop() - } else { - self.regular_entries.pop() - } - } - - fn is_empty(&self) -> bool { - self.regular_entries.is_empty() && self.overdue_entries.is_empty() - } - - fn len(&self) -> usize { - self.regular_entries.len() + self.overdue_entries.len() - } -} - /// Queue State #[derive(Debug)] struct State { - /// Queue entries - entries: QueueImpl, + /// Queue entries organized in a Vec + entries: VecDeque<(u64, Entry)>, /// Id of the next entry next_id: u64, @@ -253,12 +142,6 @@ struct State { /// Whether the model is using padding requires_padding: bool, - /// Maximum input length, required for padding scenario - max_input_length: u32, - - /// Maximum input and output length, required for padding scenario - max_total_tokens: u32, - /// Paged Attention block size block_size: u32, @@ -272,25 +155,15 @@ struct State { impl State { fn new( requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, block_size: u32, window_size: Option, speculate: u32, ) -> Self { - let default_threshold: u64 = 120; - let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") { - Ok(val) => val.parse().unwrap_or(default_threshold), - Err(_) => default_threshold, - }; - Self { - entries: QueueImpl::new(128, Duration::from_millis(threshold)), + entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, - max_input_length, - max_total_tokens, block_size, window_size, speculate, @@ -304,7 +177,7 @@ impl State { entry.temp_span = Some(queue_span); // Push entry in the queue - self.entries.push(IdentifiableEntry(self.next_id, entry)); + self.entries.push_back((self.next_id, entry)); self.next_id += 1; } @@ -329,7 +202,9 @@ impl State { } } - self.entries.update(); + // Pad prefill_token_budget to be a multiple of block size + let prefill_token_budget = + ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); @@ -339,11 +214,12 @@ impl State { let mut batch_entries = IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; // Pop entries starting from the front of the queue - while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() { + while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { @@ -355,7 +231,8 @@ impl State { if self.requires_padding { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation - prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length; + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { // pad to block size prefill_tokens += ((entry.request.input_length + self.block_size - 1) @@ -364,9 +241,7 @@ impl State { } if self.requires_padding { - // We pad to max total tokens in the Python shards - // We need to take these padding tokens into the equation - decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length); + decode_tokens += entry.request.stopping_parameters.max_new_tokens; } else { let max_new_tokens = match self.window_size { None => entry.request.stopping_parameters.max_new_tokens, @@ -387,7 +262,7 @@ impl State { // Entry is over budget // Add it back to the front tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push(IdentifiableEntry(id, entry)); + self.entries.push_front((id, entry)); break; } @@ -434,7 +309,7 @@ impl State { for r in batch_requests.into_iter().rev() { let id = r.id; let entry = batch_entries.remove(&id).unwrap(); - self.entries.push(IdentifiableEntry(id, entry)); + self.entries.push_front((id, entry)); } return None; @@ -483,18 +358,6 @@ mod tests { }; use tracing::info_span; - fn default_queue() -> Queue { - Queue::new( - true, 1, 2, 1, None, 0 - ) - } - - fn default_state() -> State { - State::new( - true, 1, 2, 1, None, 0 - ) - } - fn default_entry() -> ( Entry, mpsc::UnboundedReceiver>, @@ -538,7 +401,7 @@ mod tests { #[test] fn test_append() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -548,13 +411,13 @@ mod tests { assert_eq!(state.next_id, 1); assert_eq!(state.entries.len(), 1); - let id = state.entries.pop().unwrap().0; + let (id, _) = state.entries.remove(0).unwrap(); assert_eq!(id, 0); } #[test] fn test_next_batch_empty() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); assert!(state.next_batch(None, None, 1, 1).is_none()); assert!(state.next_batch(Some(1), None, 1, 1).is_none()); @@ -562,13 +425,13 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 2, 4).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -588,13 +451,13 @@ mod tests { assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); - let IdentifiableEntry(id, _) = state.entries.pop().unwrap(); + let (id, _) = state.entries.remove(0).unwrap(); assert_eq!(id, 2); } #[test] fn test_next_batch_max_size() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -614,13 +477,13 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -633,7 +496,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 6).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -647,14 +510,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -662,13 +525,13 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, None, 2, 4).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -685,7 +548,7 @@ mod tests { // Not enough token budget assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none()); // Ok - let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 4).await.unwrap(); + let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap(); assert_eq!(entries2.len(), 1); assert!(entries2.contains_key(&2)); assert!(entries2.get(&2).unwrap().batch_time.is_some()); @@ -695,7 +558,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -711,13 +574,13 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, None, 1, 2).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -726,7 +589,7 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); - let (entries, batch, _) = queue.next_batch(None, None, 3, 6).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -736,7 +599,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(true, 1, 2, 1, None, 2); + let queue = Queue::new(false, 1, None, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -755,7 +618,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index 1edcc472..e7570ded 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,5 +1,3 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - use crate::config::Config; /// HTTP Server logic use crate::health::Health; @@ -1493,8 +1491,6 @@ pub async fn run( max_batch_size, max_concurrent_requests, shard_info.requires_padding, - max_input_length as u32, - max_total_tokens as u32, shard_info.window_size, shard_info.speculate, generation_health, diff --git a/router/src/validation.rs b/router/src/validation.rs index ee48f705..96b6cb27 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,12 +1,9 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - use crate::config::Config; /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest, GrammarType}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; -use std::{cmp, env}; use serde_json::Value; use std::io::Cursor; use text_generation_client::{ @@ -34,7 +31,6 @@ pub struct Validation { disable_grammar_support: bool, /// Channel to communicate with the background tokenization task sender: Option>, - skip_tokenizer_in_tgi: bool, } impl Validation { @@ -77,10 +73,6 @@ impl Validation { None }; - let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI") - .ok() - .map_or(false, |value| value.to_lowercase() == "true"); - Self { max_best_of, sender, @@ -89,7 +81,6 @@ impl Validation { max_input_length, max_total_tokens, disable_grammar_support, - skip_tokenizer_in_tgi, } } @@ -128,13 +119,10 @@ impl Validation { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel - let input_length = if self.skip_tokenizer_in_tgi { - inputs.chars().filter(|&c| c == ',').count() + 1 + let input_length = if let Some(truncate) = truncate { + std::cmp::min(encoding.len(), truncate) } else { - cmp::max( - encoding.len(), - truncate.unwrap_or(self.max_input_length) - ) + encoding.len() }; // Get total tokens @@ -177,18 +165,17 @@ impl Validation { } else { return Err(ValidationError::UnsetMaxNewTokens); }; - let input_length = if self.skip_tokenizer_in_tgi { - inputs.chars().filter(|&c| c == ',').count() + 1 - } else { - truncate.unwrap_or(self.max_input_length) - }; + let mut input_length = truncate.unwrap_or(self.max_input_length); + // We don't have a tokenizer, therefore we have no idea how long is the query, let + // them through and hope for the best. // Validate MaxNewTokens if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { - return Err(ValidationError::MaxNewTokens( - self.max_total_tokens - self.max_input_length, - max_new_tokens, - )); + input_length = input_length.saturating_sub(max_new_tokens as usize); + // return Err(ValidationError::MaxNewTokens( + // self.max_total_tokens - self.max_input_length, + // max_new_tokens, + // )); } Ok((inputs, input_length, max_new_tokens)) @@ -248,14 +235,6 @@ impl Validation { return Err(ValidationError::FrequencyPenalty); } - // TODO: enable watermark with fp8 quantization - let quantization_enabled = env::var("QUANT_CONFIG") - .ok() - .map_or(false, |value| !value.is_empty()); - if watermark && quantization_enabled { - return Err(ValidationError::WatermarkWithQuantization); - } - // Different because the proto default value is not a valid value // for the user let top_p = top_p @@ -730,8 +709,6 @@ pub enum ValidationError { InvalidImageContent(String), #[error("Could not fetch image: {0}")] FailedFetchImage(#[from] reqwest::Error), - #[error("`watermark` = true is not allowed with FP8 quantization.")] - WatermarkWithQuantization, } #[cfg(test)] @@ -768,7 +745,8 @@ mod tests { .validate_input("Hello".to_string(), None, Some(max_new_tokens)) .await { - Err(ValidationError::MaxNewTokens(1, 10)) => (), + // Err(ValidationError::MaxNewTokens(1, 10)) => (), + Ok((_s, 0, 10)) => (), r => panic!("Unexpected not max new tokens: {r:?}"), } } @@ -801,7 +779,7 @@ mod tests { .validate_input("Hello".to_string(), None, Some(max_new_tokens)) .await { - Err(ValidationError::MaxTotalTokens(6, 5, 10)) => (), + Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), _ => panic!("Unexpected not max new tokens"), } } diff --git a/server/Makefile b/server/Makefile index 7e38eb12..32d01709 100644 --- a/server/Makefile +++ b/server/Makefile @@ -19,18 +19,12 @@ gen-server: install: gen-server pip install pip --upgrade - pip install -r requirements.txt - pip install -e "." + pip install -r requirements_cuda.txt + pip install -e ".[bnb, accelerate, quantize, peft, outlines]" run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded -install-poetry: - curl -sSL https://install.python-poetry.org | python3 - - -update-lock: - rm poetry.lock - poetry lock --no-update - export-requirements: - poetry export -o requirements.txt --without-hashes + poetry export -o requirements_cuda.txt --without-hashes + poetry export -o requirements_rocm.txt --without-hashes diff --git a/server/dill-0.3.7-patch.sh b/server/dill-0.3.7-patch.sh deleted file mode 100644 index ad8c8be5..00000000 --- a/server/dill-0.3.7-patch.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git -pushd dill -cat < dill-0.3.7.patch -diff --git a/dill/_dill.py b/dill/_dill.py -index d0cf543..f6eb662 100644 ---- a/dill/_dill.py -+++ b/dill/_dill.py -@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered - XRangeType = range - from types import MappingProxyType as DictProxyType, new_class - from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError --import __main__ as _main_module -+class _LazyMainModule(object): -+ _module = None -+ @property -+ def module(self): -+ if self._module is None: -+ import __main__ as _m_module -+ self._module = _m_module -+ return self._module -+_main_module = _LazyMainModule() - import marshal - import gc - # import zlib -@@ -353,7 +361,7 @@ class Pickler(StockPickler): - _fmode = kwds.pop('fmode', None) - _recurse = kwds.pop('recurse', None) - StockPickler.__init__(self, file, *args, **kwds) -- self._main = _main_module -+ self._main = _main_module.module - self._diff_cache = {} - self._byref = settings['byref'] if _byref is None else _byref - self._strictio = False #_strictio -@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler): - settings = Pickler.settings - _ignore = kwds.pop('ignore', None) - StockUnpickler.__init__(self, *args, **kwds) -- self._main = _main_module -+ self._main = _main_module.module - self._ignore = settings['ignore'] if _ignore is None else _ignore - - def load(self): #NOTE: if settings change, need to update attributes - obj = StockUnpickler.load(self) -- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): -+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): - if not self._ignore: - # point obj class to main - try: obj.__class__ = getattr(self._main, type(obj).__name__) -@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj): - logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) - logger.trace(pickler, "# D1") -- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): -+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): - logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? - logger.trace(pickler, "# D3") -- elif '__name__' in obj and obj != _main_module.__dict__ \\ -+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ - and type(obj['__name__']) is str \\ - and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): - logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj -diff --git a/dill/session.py b/dill/session.py -index 74234ab..1be8d89 100644 ---- a/dill/session.py -+++ b/dill/session.py -@@ -233,7 +233,7 @@ def dump_module( - protocol = settings['protocol'] - main = module - if main is None: -- main = _main_module -+ main = _main_module.module - elif isinstance(main, str): - main = _import_module(main) - if not isinstance(main, ModuleType): -@@ -501,7 +501,7 @@ def load_module( - pass - assert loaded is main - _restore_modules(unpickler, main) -- if main is _main_module or main is module: -+ if main is _main_module.module or main is module: - return None - else: - return main - -EOF -git apply dill-0.3.7.patch -python -m pip install . -popd -rm -fr dill diff --git a/server/dill-0.3.8-patch.sh b/server/dill-0.3.8-patch.sh deleted file mode 100644 index da263960..00000000 --- a/server/dill-0.3.8-patch.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -git clone -b 0.3.8 https://github.com/uqfoundation/dill.git -pushd dill -cat < dill-0.3.8.patch -diff --git a/dill/_dill.py b/dill/_dill.py -index d42432f..1d251e6 100644 ---- a/dill/_dill.py -+++ b/dill/_dill.py -@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered - XRangeType = range - from types import MappingProxyType as DictProxyType, new_class - from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError --import __main__ as _main_module -+class _LazyMainModule(object): -+ _module = None -+ @property -+ def module(self): -+ if self._module is None: -+ import __main__ as _m_module -+ self._module = _m_module -+ return self._module -+_main_module = _LazyMainModule() - import marshal - import gc - # import zlib -@@ -355,7 +363,7 @@ class Pickler(StockPickler): - _fmode = kwds.pop('fmode', None) - _recurse = kwds.pop('recurse', None) - StockPickler.__init__(self, file, *args, **kwds) -- self._main = _main_module -+ self._main = _main_module.module - self._diff_cache = {} - self._byref = settings['byref'] if _byref is None else _byref - self._strictio = False #_strictio -@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler): - settings = Pickler.settings - _ignore = kwds.pop('ignore', None) - StockUnpickler.__init__(self, *args, **kwds) -- self._main = _main_module -+ self._main = _main_module.module - self._ignore = settings['ignore'] if _ignore is None else _ignore - - def load(self): #NOTE: if settings change, need to update attributes - obj = StockUnpickler.load(self) -- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): -+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): - if not self._ignore: - # point obj class to main - try: obj.__class__ = getattr(self._main, type(obj).__name__) -@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj): - logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) - logger.trace(pickler, "# D1") -- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): -+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): - logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? - logger.trace(pickler, "# D3") -- elif '__name__' in obj and obj != _main_module.__dict__ \\ -+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ - and type(obj['__name__']) is str \\ - and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): - logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj -diff --git a/dill/session.py b/dill/session.py -index e91068a..a921b43 100644 ---- a/dill/session.py -+++ b/dill/session.py -@@ -233,7 +233,7 @@ def dump_module( - protocol = settings['protocol'] - main = module - if main is None: -- main = _main_module -+ main = _main_module.module - elif isinstance(main, str): - main = _import_module(main) - if not isinstance(main, ModuleType): -@@ -501,7 +501,7 @@ def load_module( - pass - assert loaded is main - _restore_modules(unpickler, main) -- if main is _main_module or main is module: -+ if main is _main_module.module or main is module: - return None - else: - return main - -EOF -git apply dill-0.3.8.patch -python -m pip install . -popd -rm -fr dill \ No newline at end of file diff --git a/server/poetry.lock b/server/poetry.lock index 8dfc122c..2bf4ca22 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,19 +1,19 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "accelerate" -version = "0.33.0" +version = "0.29.3" description = "Accelerate" -optional = false +optional = true python-versions = ">=3.8.0" files = [ - {file = "accelerate-0.33.0-py3-none-any.whl", hash = "sha256:0a7f33d60ba09afabd028d4f0856dd19c5a734b7a596d637d9dd6e3d0eadbaf3"}, - {file = "accelerate-0.33.0.tar.gz", hash = "sha256:11ba481ed6ea09191775df55ce464aeeba67a024bd0261a44b77b30fb439e26a"}, + {file = "accelerate-0.29.3-py3-none-any.whl", hash = "sha256:99d633d4b6126817c5e554487406748be95c8d1d1e659dd2fd60657e35f532dd"}, + {file = "accelerate-0.29.3.tar.gz", hash = "sha256:1a5a845b06b24b41736b219b2b20fd021ca5dff4070a252445fd6de736e347ac"}, ] [package.dependencies] -huggingface-hub = ">=0.21.0" -numpy = ">=1.17,<2.0.0" +huggingface-hub = "*" +numpy = ">=1.17" packaging = ">=20.0" psutil = "*" pyyaml = "*" @@ -21,129 +21,101 @@ safetensors = ">=0.3.1" torch = ">=1.10.0" [package.extras] -deepspeed = ["deepspeed (<=0.14.0)"] -dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "diffusers", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"] rich = ["rich"] sagemaker = ["sagemaker"] -test-dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"] test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] -testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] - -[[package]] -name = "aiohappyeyeballs" -version = "2.4.0" -description = "Happy Eyeballs for asyncio" -optional = false -python-versions = ">=3.8" -files = [ - {file = "aiohappyeyeballs-2.4.0-py3-none-any.whl", hash = "sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd"}, - {file = "aiohappyeyeballs-2.4.0.tar.gz", hash = "sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2"}, -] +testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] [[package]] name = "aiohttp" -version = "3.10.5" +version = "3.9.5" description = "Async http client/server framework (asyncio)" -optional = false +optional = true python-versions = ">=3.8" files = [ - {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, - {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, - {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, - {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, - {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, - {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, - {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, - {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, - {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, - {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, - {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, - {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, - {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, - {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, - {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, - {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, - {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, - {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, - {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, - {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, - {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, - {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, - {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, - {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, - {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, - {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, - {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, - {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, - {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, - {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, - {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"}, + {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"}, + {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"}, + {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"}, + {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"}, + {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"}, + {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:694d828b5c41255e54bc2dddb51a9f5150b4eefa9886e38b52605a05d96566e8"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4558e5012ee03d2638c681e156461d37b7a113fe13970d438d95d10173d25f78"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dbc053ac75ccc63dc3a3cc547b98c7258ec35a215a92bd9f983e0aac95d3d5b"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4109adee842b90671f1b689901b948f347325045c15f46b39797ae1bf17019de"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6ea1a5b409a85477fd8e5ee6ad8f0e40bf2844c270955e09360418cfd09abac"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3c2890ca8c59ee683fd09adf32321a40fe1cf164e3387799efb2acebf090c11"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3916c8692dbd9d55c523374a3b8213e628424d19116ac4308e434dbf6d95bbdd"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8d1964eb7617907c792ca00b341b5ec3e01ae8c280825deadbbd678447b127e1"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5ab8e1f6bee051a4bf6195e38a5c13e5e161cb7bad83d8854524798bd9fcd6e"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:52c27110f3862a1afbcb2af4281fc9fdc40327fa286c4625dfee247c3ba90156"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:7f64cbd44443e80094309875d4f9c71d0401e966d191c3d469cde4642bc2e031"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b4f72fbb66279624bfe83fd5eb6aea0022dad8eec62b71e7bf63ee1caadeafe"}, + {file = "aiohttp-3.9.5-cp38-cp38-win32.whl", hash = "sha256:6380c039ec52866c06d69b5c7aad5478b24ed11696f0e72f6b807cfb261453da"}, + {file = "aiohttp-3.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:da22dab31d7180f8c3ac7c7635f3bcd53808f374f6aa333fe0b0b9e14b01f91a"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"}, + {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"}, + {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"}, + {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"}, ] [package.dependencies] -aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" @@ -152,13 +124,13 @@ multidict = ">=4.5,<7.0" yarl = ">=1.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +speedups = ["Brotli", "aiodns", "brotlicffi"] [[package]] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, @@ -183,7 +155,7 @@ files = [ name = "async-timeout" version = "4.0.3" description = "Timeout context manager for asyncio programs" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, @@ -192,22 +164,22 @@ files = [ [[package]] name = "attrs" -version = "24.2.0" +version = "23.2.0" description = "Classes Without Boilerplate" -optional = false +optional = true python-versions = ">=3.7" files = [ - {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, - {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] [[package]] name = "backoff" @@ -220,15 +192,34 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "bitsandbytes" +version = "0.43.1" +description = "k-bit optimizers and matrix multiplication routines." +optional = true +python-versions = "*" +files = [ + {file = "bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a81c826d576d6d691c7b4a7491c8fdc0f37f769795d6ca2e54afa605d2c260a3"}, + {file = "bitsandbytes-0.43.1-py3-none-win_amd64.whl", hash = "sha256:52c1c7189a6ca006555a9663e544e75f40520a97a26e075411f9f9aca0771fcd"}, +] + +[package.dependencies] +numpy = "*" +torch = "*" + +[package.extras] +benchmark = ["matplotlib", "pandas"] +test = ["scipy"] + [[package]] name = "certifi" -version = "2024.7.4" +version = "2024.2.2" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, - {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, + {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, + {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, ] [[package]] @@ -366,66 +357,47 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -[[package]] -name = "coloredlogs" -version = "15.0.1" -description = "Colored terminal output for Python's logging module" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -files = [ - {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, - {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"}, -] - -[package.dependencies] -humanfriendly = ">=9.1" - -[package.extras] -cron = ["capturer (>=2.4)"] - [[package]] name = "datasets" -version = "2.21.0" +version = "2.14.4" description = "HuggingFace community-driven open-source library of datasets" -optional = false +optional = true python-versions = ">=3.8.0" files = [ - {file = "datasets-2.21.0-py3-none-any.whl", hash = "sha256:25e4e097110ce28824b746a107727ada94024cba11db8bc588d468414692b65a"}, - {file = "datasets-2.21.0.tar.gz", hash = "sha256:998f85a8460f1bd982e5bd058f8a0808eef424249e3df1e8cdd594ccd0dc8ba2"}, + {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, + {file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, ] [package.dependencies] aiohttp = "*" -dill = ">=0.3.0,<0.3.9" -filelock = "*" -fsspec = {version = ">=2023.1.0,<=2024.6.1", extras = ["http"]} -huggingface-hub = ">=0.21.2" +dill = ">=0.3.0,<0.3.8" +fsspec = {version = ">=2021.11.1", extras = ["http"]} +huggingface-hub = ">=0.14.0,<1.0.0" multiprocess = "*" numpy = ">=1.17" packaging = "*" pandas = "*" -pyarrow = ">=15.0.0" +pyarrow = ">=8.0.0" pyyaml = ">=5.1" -requests = ">=2.32.2" -tqdm = ">=4.66.3" +requests = ">=2.19.0" +tqdm = ">=4.62.1" xxhash = "*" [package.extras] -apache-beam = ["apache-beam (>=2.26.0)"] -audio = ["librosa", "soundfile (>=0.12.1)", "soxr (>=0.4.0)"] +apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0)", "tensorflow (>=2.6.0)", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "transformers (>=4.42.0)", "typing-extensions (>=4.6.1)", "zstandard"] -docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"] -jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] -metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk (<3.8.2)", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] -quality = ["ruff (>=0.3.0)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] s3 = ["s3fs"] -tensorflow = ["tensorflow (>=2.6.0)"] -tensorflow-gpu = ["tensorflow (>=2.6.0)"] -tests = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0)", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers (>=4.42.0)", "typing-extensions (>=4.6.1)", "zstandard"] -tests-numpy2 = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tiktoken", "torch (>=2.0.0)", "typing-extensions (>=4.6.1)", "zstandard"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] +tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] torch = ["torch"] -vision = ["Pillow (>=9.4.0)"] +vision = ["Pillow (>=6.2.1)"] [[package]] name = "deprecated" @@ -444,50 +416,19 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] -[[package]] -name = "diffusers" -version = "0.29.2" -description = "State-of-the-art diffusion in PyTorch and JAX." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "diffusers-0.29.2-py3-none-any.whl", hash = "sha256:d5e9bb13c8097b4eed10df23d1294d2e5a418f53e3f89c7ef228b5b982970428"}, - {file = "diffusers-0.29.2.tar.gz", hash = "sha256:b85f277668e22089cf68b40dd9b76940db7d24ba9cdac107533ed10ab8e4e9db"}, -] - -[package.dependencies] -filelock = "*" -huggingface-hub = ">=0.23.2" -importlib-metadata = "*" -numpy = "*" -Pillow = "*" -regex = "!=2019.12.17" -requests = "*" -safetensors = ">=0.3.1" - -[package.extras] -dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.29.3)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.25.1)", "urllib3 (<=2.0.0)"] -docs = ["hf-doc-builder (>=0.3.0)"] -flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"] -quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"] -test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.25.1)"] -torch = ["accelerate (>=0.29.3)", "torch (>=1.4)"] -training = ["Jinja2", "accelerate (>=0.29.3)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] - [[package]] name = "dill" -version = "0.3.8" +version = "0.3.7" description = "serialize all of Python" -optional = false -python-versions = ">=3.8" +optional = true +python-versions = ">=3.7" files = [ - {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, - {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, + {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, + {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, ] [package.extras] graph = ["objgraph (>=1.7.2)"] -profile = ["gprof2dot (>=2022.7.29)"] [[package]] name = "diskcache" @@ -500,15 +441,26 @@ files = [ {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, ] +[[package]] +name = "einops" +version = "0.6.1" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.7" +files = [ + {file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"}, + {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"}, +] + [[package]] name = "exceptiongroup" -version = "1.2.2" +version = "1.2.1" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, - {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, + {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, + {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, ] [package.extras] @@ -516,25 +468,25 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.15.4" +version = "3.14.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, - {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, + {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, + {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] name = "frozenlist" version = "1.4.1" description = "A list-like structure which implements collections.abc.MutableSequence" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, @@ -618,13 +570,13 @@ files = [ [[package]] name = "fsspec" -version = "2024.6.1" +version = "2024.5.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"}, - {file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"}, + {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"}, + {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"}, ] [package.dependencies] @@ -636,7 +588,6 @@ adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] dev = ["pre-commit", "ruff"] -doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] @@ -660,17 +611,17 @@ tqdm = ["tqdm"] [[package]] name = "googleapis-common-protos" -version = "1.63.2" +version = "1.63.0" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" files = [ - {file = "googleapis-common-protos-1.63.2.tar.gz", hash = "sha256:27c5abdffc4911f28101e635de1533fb4cfd2c37fbaa9174587c799fac90aa87"}, - {file = "googleapis_common_protos-1.63.2-py2.py3-none-any.whl", hash = "sha256:27a2499c7e8aff199665b22741997e485eccc8645aa9176c7c988e6fae507945"}, + {file = "googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e"}, + {file = "googleapis_common_protos-1.63.0-py2.py3-none-any.whl", hash = "sha256:ae45f75702f7c08b541f750854a678bd8f534a1a6bace6afe975f1d0a82d6632"}, ] [package.dependencies] -protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] @@ -694,226 +645,242 @@ testing = ["protobuf (>=4.21.9)"] [[package]] name = "grpcio" -version = "1.66.0" +version = "1.64.0" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.66.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:ad7256f224437b2c29c2bef98ddd3130454c5b1ab1f0471fc11794cefd4dbd3d"}, - {file = "grpcio-1.66.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5f4b3357e59dfba9140a51597287297bc638710d6a163f99ee14efc19967a821"}, - {file = "grpcio-1.66.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:e8d20308eeae15b3e182f47876f05acbdec1eebd9473a9814a44e46ec4a84c04"}, - {file = "grpcio-1.66.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1eb03524d0f55b965d6c86aa44e5db9e5eaa15f9ed3b164621e652e5b927f4b8"}, - {file = "grpcio-1.66.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37514b68a42e9cf24536345d3cf9e580ffd29117c158b4eeea34625200256067"}, - {file = "grpcio-1.66.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:516fdbc8e156db71a004bc431a6303bca24cfde186babe96dde7bd01e8f0cc70"}, - {file = "grpcio-1.66.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d0439a970d65327de21c299ea0e0c2ad0987cdaf18ba5066621dea5f427f922b"}, - {file = "grpcio-1.66.0-cp310-cp310-win32.whl", hash = "sha256:5f93fc84b72bbc7b84a42f3ca9dc055fa00d2303d9803be011ebf7a10a4eb833"}, - {file = "grpcio-1.66.0-cp310-cp310-win_amd64.whl", hash = "sha256:8fc5c710ddd51b5a0dc36ef1b6663430aa620e0ce029b87b150dafd313b978c3"}, - {file = "grpcio-1.66.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:dd614370e939f9fceeeb2915111a0795271b4c11dfb5fc0f58449bee40c726a5"}, - {file = "grpcio-1.66.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:245b08f9b3c645a6a623f3ed4fa43dcfcd6ad701eb9c32511c1bb7380e8c3d23"}, - {file = "grpcio-1.66.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:aaf30c75cbaf30e561ca45f21eb1f729f0fab3f15c592c1074795ed43e3ff96f"}, - {file = "grpcio-1.66.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49234580a073ce7ac490112f6c67c874cbcb27804c4525978cdb21ba7f3f193c"}, - {file = "grpcio-1.66.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de9e20a0acb709dcfa15a622c91f584f12c9739a79c47999f73435d2b3cc8a3b"}, - {file = "grpcio-1.66.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bc008c6afa1e7c8df99bd9154abc4f0470d26b7730ca2521122e99e771baa8c7"}, - {file = "grpcio-1.66.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:50cea8ce2552865b87e3dffbb85eb21e6b98d928621600c0feda2f02449cd837"}, - {file = "grpcio-1.66.0-cp311-cp311-win32.whl", hash = "sha256:508411df1f2b7cfa05d4d7dbf3d576fe4f949cd61c03f3a6f0378c84e3d7b963"}, - {file = "grpcio-1.66.0-cp311-cp311-win_amd64.whl", hash = "sha256:6d586a95c05c82a5354be48bb4537e1accaf2472d8eb7e9086d844cbff934482"}, - {file = "grpcio-1.66.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:5ea27f4ce8c0daccfdd2c7961e6ba404b6599f47c948415c4cca5728739107a3"}, - {file = "grpcio-1.66.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:296a45ea835e12a1cc35ab0c57e455346c272af7b0d178e29c67742167262b4c"}, - {file = "grpcio-1.66.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:e36fa838ac1d6c87198ca149cbfcc92e1af06bb8c8cd852622f8e58f33ea3324"}, - {file = "grpcio-1.66.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:684a4c07883cbd4ac864f0d08d927267404f5f0c76f31c85f9bbe05f2daae2f2"}, - {file = "grpcio-1.66.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3084e590e857ba7585ae91078e4c9b6ef55aaf1dc343ce26400ba59a146eada"}, - {file = "grpcio-1.66.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:526d4f6ca19f31b25606d5c470ecba55c0b22707b524e4de8987919e8920437d"}, - {file = "grpcio-1.66.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:423ae18637cd99ddcf2e5a6851c61828c49e9b9d022d0442d979b4f230109787"}, - {file = "grpcio-1.66.0-cp312-cp312-win32.whl", hash = "sha256:7bc9d823e05d63a87511fb456dcc48dc0fced86c282bf60229675e7ee7aac1a1"}, - {file = "grpcio-1.66.0-cp312-cp312-win_amd64.whl", hash = "sha256:230cdd696751e7eb1395718cd308234749daa217bb8d128f00357dc4df102558"}, - {file = "grpcio-1.66.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:0f3010bf46b2a01c9e40644cb9ed91b4b8435e5c500a275da5f9f62580e31e80"}, - {file = "grpcio-1.66.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ba18cfdc09312eb2eea6fa0ce5d2eec3cf345ea78f6528b2eaed6432105e0bd0"}, - {file = "grpcio-1.66.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:53d4c6706b49e358a2a33345dbe9b6b3bb047cecd7e8c07ba383bd09349bfef8"}, - {file = "grpcio-1.66.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:643d8d9632a688ae69661e924b862e23c83a3575b24e52917ec5bcc59543d212"}, - {file = "grpcio-1.66.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba60ae3b465b3e85080ae3bfbc36fd0305ae495ab16fcf8022fc7d7a23aac846"}, - {file = "grpcio-1.66.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9d5251578767fe44602688c851c2373b5513048ac84c21a0fe946590a8e7933d"}, - {file = "grpcio-1.66.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5e8140b39f10d7be2263afa2838112de29374c5c740eb0afd99146cb5bdbd990"}, - {file = "grpcio-1.66.0-cp38-cp38-win32.whl", hash = "sha256:5b15ef1b296c4e78f15f64fc65bf8081f8774480ffcac45642f69d9d753d9c6b"}, - {file = "grpcio-1.66.0-cp38-cp38-win_amd64.whl", hash = "sha256:c072f90a1f0409f827ae86266984cba65e89c5831a0726b9fc7f4b5fb940b853"}, - {file = "grpcio-1.66.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:a639d3866bfb5a678b5c0b92cd7ab543033ed8988854290fd86145e71731fd4c"}, - {file = "grpcio-1.66.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6ed35bf7da3fb3b1949e32bdf47a8b5ffe0aed11722d948933bd068531cd4682"}, - {file = "grpcio-1.66.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:1c5466222470cb7fbc9cc898af1d48eefd297cb2e2f59af6d4a851c862fa90ac"}, - {file = "grpcio-1.66.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:921b8f7f25d5300d7c6837a1e0639ef145fbdbfb728e0a5db2dbccc9fc0fd891"}, - {file = "grpcio-1.66.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3f6feb0dc8456d025e566709f7dd02885add99bedaac50229013069242a1bfd"}, - {file = "grpcio-1.66.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748452dbd5a047475d5413bdef08b0b9ceb2c0c0e249d4ee905a5fb82c6328dc"}, - {file = "grpcio-1.66.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:832945e64176520520317b50d64ec7d79924429528d5747669b52d0bf2c7bd78"}, - {file = "grpcio-1.66.0-cp39-cp39-win32.whl", hash = "sha256:8096a922eb91bc97c839f675c3efa1257c6ef181ae1b25d3fb97f2cae4c57c01"}, - {file = "grpcio-1.66.0-cp39-cp39-win_amd64.whl", hash = "sha256:375b58892301a5fc6ca7d7ff689c9dc9d00895f5d560604ace9f4f0573013c63"}, - {file = "grpcio-1.66.0.tar.gz", hash = "sha256:c1ea4c528e7db6660718e4165fd1b5ac24b79a70c870a7bc0b7bdb9babab7c1e"}, + {file = "grpcio-1.64.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3b09c3d9de95461214a11d82cc0e6a46a6f4e1f91834b50782f932895215e5db"}, + {file = "grpcio-1.64.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7e013428ab472892830287dd082b7d129f4d8afef49227a28223a77337555eaa"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:02cc9cc3f816d30f7993d0d408043b4a7d6a02346d251694d8ab1f78cc723e7e"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f5de082d936e0208ce8db9095821361dfa97af8767a6607ae71425ac8ace15c"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b7bf346391dffa182fba42506adf3a84f4a718a05e445b37824136047686a1"}, + {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b2cbdfba18408389a1371f8c2af1659119e1831e5ed24c240cae9e27b4abc38d"}, + {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:aca4f15427d2df592e0c8f3d38847e25135e4092d7f70f02452c0e90d6a02d6d"}, + {file = "grpcio-1.64.0-cp310-cp310-win32.whl", hash = "sha256:7c1f5b2298244472bcda49b599be04579f26425af0fd80d3f2eb5fd8bc84d106"}, + {file = "grpcio-1.64.0-cp310-cp310-win_amd64.whl", hash = "sha256:73f84f9e5985a532e47880b3924867de16fa1aa513fff9b26106220c253c70c5"}, + {file = "grpcio-1.64.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2a18090371d138a57714ee9bffd6c9c9cb2e02ce42c681aac093ae1e7189ed21"}, + {file = "grpcio-1.64.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:59c68df3a934a586c3473d15956d23a618b8f05b5e7a3a904d40300e9c69cbf0"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b52e1ec7185512103dd47d41cf34ea78e7a7361ba460187ddd2416b480e0938c"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d598b5d5e2c9115d7fb7e2cb5508d14286af506a75950762aa1372d60e41851"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01615bbcae6875eee8091e6b9414072f4e4b00d8b7e141f89635bdae7cf784e5"}, + {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0b2dfe6dcace264807d9123d483d4c43274e3f8c39f90ff51de538245d7a4145"}, + {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7f17572dc9acd5e6dfd3014d10c0b533e9f79cd9517fc10b0225746f4c24b58e"}, + {file = "grpcio-1.64.0-cp311-cp311-win32.whl", hash = "sha256:6ec5ed15b4ffe56e2c6bc76af45e6b591c9be0224b3fb090adfb205c9012367d"}, + {file = "grpcio-1.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:597191370951b477b7a1441e1aaa5cacebeb46a3b0bd240ec3bb2f28298c7553"}, + {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, + {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, + {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, + {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, + {file = "grpcio-1.64.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c46fb6bfca17bfc49f011eb53416e61472fa96caa0979b4329176bdd38cbbf2a"}, + {file = "grpcio-1.64.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3d2004e85cf5213995d09408501f82c8534700d2babeb81dfdba2a3bff0bb396"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6d5541eb460d73a07418524fb64dcfe0adfbcd32e2dac0f8f90ce5b9dd6c046c"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f279ad72dd7d64412e10f2443f9f34872a938c67387863c4cd2fb837f53e7d2"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85fda90b81da25993aa47fae66cae747b921f8f6777550895fb62375b776a231"}, + {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a053584079b793a54bece4a7d1d1b5c0645bdbee729215cd433703dc2532f72b"}, + {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:579dd9fb11bc73f0de061cab5f8b2def21480fd99eb3743ed041ad6a1913ee2f"}, + {file = "grpcio-1.64.0-cp38-cp38-win32.whl", hash = "sha256:23b6887bb21d77649d022fa1859e05853fdc2e60682fd86c3db652a555a282e0"}, + {file = "grpcio-1.64.0-cp38-cp38-win_amd64.whl", hash = "sha256:753cb58683ba0c545306f4e17dabf468d29cb6f6b11832e1e432160bb3f8403c"}, + {file = "grpcio-1.64.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:2186d76a7e383e1466e0ea2b0febc343ffeae13928c63c6ec6826533c2d69590"}, + {file = "grpcio-1.64.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0f30596cdcbed3c98024fb4f1d91745146385b3f9fd10c9f2270cbfe2ed7ed91"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:d9171f025a196f5bcfec7e8e7ffb7c3535f7d60aecd3503f9e250296c7cfc150"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf4c8daed18ae2be2f1fc7d613a76ee2a2e28fdf2412d5c128be23144d28283d"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3550493ac1d23198d46dc9c9b24b411cef613798dc31160c7138568ec26bc9b4"}, + {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3161a8f8bb38077a6470508c1a7301cd54301c53b8a34bb83e3c9764874ecabd"}, + {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e8fabe2cc57a369638ab1ad8e6043721014fdf9a13baa7c0e35995d3a4a7618"}, + {file = "grpcio-1.64.0-cp39-cp39-win32.whl", hash = "sha256:31890b24d47b62cc27da49a462efe3d02f3c120edb0e6c46dcc0025506acf004"}, + {file = "grpcio-1.64.0-cp39-cp39-win_amd64.whl", hash = "sha256:5a56797dea8c02e7d3a85dfea879f286175cf4d14fbd9ab3ef2477277b927baa"}, + {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.66.0)"] +protobuf = ["grpcio-tools (>=1.64.0)"] [[package]] name = "grpcio-reflection" -version = "1.48.2" +version = "1.62.2" description = "Standard Protobuf Reflection Service for gRPC" optional = false python-versions = ">=3.6" files = [ - {file = "grpcio-reflection-1.48.2.tar.gz", hash = "sha256:b687acc86c736ba8273523e1cdd5f31155dccabf7f9b2acfb62bf4e9c79d3b5a"}, - {file = "grpcio_reflection-1.48.2-py3-none-any.whl", hash = "sha256:280bf4569149126050b587ff9177051a409ee98882028dcf0c9caa3c2d31f6fe"}, + {file = "grpcio-reflection-1.62.2.tar.gz", hash = "sha256:2dd44806d68d0006636529bda573012b19a42281478c2d051cdaaebb91e2516c"}, + {file = "grpcio_reflection-1.62.2-py3-none-any.whl", hash = "sha256:68e8dff3617a9afaf7c462c688f7ca62b55323f497c662abf9965f2953508885"}, ] [package.dependencies] -grpcio = ">=1.48.2" -protobuf = ">=3.12.0" +grpcio = ">=1.62.2" +protobuf = ">=4.21.6" [[package]] name = "grpcio-status" -version = "1.48.2" +version = "1.62.2" description = "Status proto mapping for gRPC" optional = false python-versions = ">=3.6" files = [ - {file = "grpcio-status-1.48.2.tar.gz", hash = "sha256:53695f45da07437b7c344ee4ef60d370fd2850179f5a28bb26d8e2aa1102ec11"}, - {file = "grpcio_status-1.48.2-py3-none-any.whl", hash = "sha256:2c33bbdbe20188b2953f46f31af669263b6ee2a9b2d38fa0d36ee091532e21bf"}, + {file = "grpcio-status-1.62.2.tar.gz", hash = "sha256:62e1bfcb02025a1cd73732a2d33672d3e9d0df4d21c12c51e0bbcaf09bab742a"}, + {file = "grpcio_status-1.62.2-py3-none-any.whl", hash = "sha256:206ddf0eb36bc99b033f03b2c8e95d319f0044defae9b41ae21408e7e0cda48f"}, ] [package.dependencies] googleapis-common-protos = ">=1.5.5" -grpcio = ">=1.48.2" -protobuf = ">=3.12.0" +grpcio = ">=1.62.2" +protobuf = ">=4.21.6" [[package]] name = "grpcio-tools" -version = "1.48.2" +version = "1.62.2" description = "Protobuf code generator for gRPC" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "grpcio-tools-1.48.2.tar.gz", hash = "sha256:8902a035708555cddbd61b5467cea127484362decc52de03f061a1a520fe90cd"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:92acc3e10ba2b0dcb90a88ae9fe1cc0ffba6868545207e4ff20ca95284f8e3c9"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e5bb396d63495667d4df42e506eed9d74fc9a51c99c173c04395fe7604c848f1"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:84a84d601a238572d049d3108e04fe4c206536e81076d56e623bd525a1b38def"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70564521e86a0de35ea9ac6daecff10cb46860aec469af65869974807ce8e98b"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdbbe63f6190187de5946891941629912ac8196701ed2253fa91624a397822ec"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae56f133b05b7e5d780ef7e032dd762adad7f3dc8f64adb43ff5bfabd659f435"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0feb4f2b777fa6377e977faa89c26359d4f31953de15e035505b92f41aa6906"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-win32.whl", hash = "sha256:80f450272316ca0924545f488c8492649ca3aeb7044d4bf59c426dcdee527f7c"}, - {file = "grpcio_tools-1.48.2-cp310-cp310-win_amd64.whl", hash = "sha256:21ff50e321736eba22210bf9b94e05391a9ac345f26e7df16333dc75d63e74fb"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-linux_armv7l.whl", hash = "sha256:d598ccde6338b2cfbb3124f34c95f03394209013f9b1ed4a5360a736853b1c27"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:a43d26714933f23de93ea0bf9c86c66a6ede709b8ca32e357f9e2181703e64ae"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_aarch64.whl", hash = "sha256:55fdebc73fb580717656b1bafa4f8eca448726a7aa22726a6c0a7895d2f0f088"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8588819b22d0de3aa1951e1991cc3e4b9aa105eecf6e3e24eb0a2fc8ab958b3e"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9771d4d317dca029dfaca7ec9282d8afe731c18bc536ece37fd39b8a974cc331"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d886a9e052a038642b3af5d18e6f2085d1656d9788e202dc23258cf3a751e7ca"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d77e8b1613876e0d8fd17709509d4ceba13492816426bd156f7e88a4c47e7158"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-win32.whl", hash = "sha256:dcaaecdd5e847de5c1d533ea91522bf56c9e6b2dc98cdc0d45f0a1c26e846ea2"}, - {file = "grpcio_tools-1.48.2-cp36-cp36m-win_amd64.whl", hash = "sha256:0119aabd9ceedfdf41b56b9fdc8284dd85a7f589d087f2694d743f346a368556"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:189be2a9b672300ca6845d94016bdacc052fdbe9d1ae9e85344425efae2ff8ef"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:9443f5c30bac449237c3cf99da125f8d6e6c01e17972bc683ee73b75dea95573"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:e0403e095b343431195db1305248b50019ad55d3dd310254431af87e14ef83a2"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5410d6b601d1404835e34466bd8aee37213489b36ee1aad2276366e265ff29d4"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51be91b7c7056ff9ee48b1eccd4a2840b0126230803a5e09dfc082a5b16a91c1"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:516eedd5eb7af6326050bc2cfceb3a977b9cc1144f283c43cc4956905285c912"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d18599ab572b2f15a8f3db49503272d1bb4fcabb4b4d1214ef03aca1816b20a0"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-win32.whl", hash = "sha256:d18ef2adc05a8ef9e58ac46357f6d4ce7e43e077c7eda0a4425773461f9d0e6e"}, - {file = "grpcio_tools-1.48.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6d9753944e5a6b6b78b76ce9d2ae0fe3f748008c1849deb7fadcb64489d6553b"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:3c8749dca04a8d302862ceeb1dfbdd071ee13b281395975f24405a347e5baa57"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:7307dd2408b82ea545ae63502ec03036b025f449568556ea9a056e06129a7a4e"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:072234859f6069dc43a6be8ad6b7d682f4ba1dc2e2db2ebf5c75f62eee0f6dfb"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6cc298fbfe584de8876a85355efbcf796dfbcfac5948c9560f5df82e79336e2a"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75973a42c710999acd419968bc79f00327e03e855bbe82c6529e003e49af660"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f766050e491d0b3203b6b85638015f543816a2eb7d089fc04e86e00f6de0e31d"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8e0d74403484eb77e8df2566a64b8b0b484b5c87903678c381634dd72f252d5e"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-win32.whl", hash = "sha256:cb75bac0cd43858cb759ef103fe68f8c540cb58b63dda127e710228fec3007b8"}, - {file = "grpcio_tools-1.48.2-cp38-cp38-win_amd64.whl", hash = "sha256:cabc8b0905cedbc3b2b7b2856334fa35cce3d4bc79ae241cacd8cca8940a5c85"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:e712a6d00606ad19abdeae852a7e521d6f6d0dcea843708fecf3a38be16a851e"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:e7e7668f89fd598c5469bb58e16bfd12b511d9947ccc75aec94da31f62bc3758"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:a415fbec67d4ff7efe88794cbe00cf548d0f0a5484cceffe0a0c89d47694c491"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d96e96ae7361aa51c9cd9c73b677b51f691f98df6086860fcc3c45852d96b0b0"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e20d7885a40e68a2bda92908acbabcdf3c14dd386c3845de73ba139e9df1f132"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8a5614251c46da07549e24f417cf989710250385e9d80deeafc53a0ee7df6325"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ace0035766fe01a1b096aa050be9f0a9f98402317e7aeff8bfe55349be32a407"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-win32.whl", hash = "sha256:4fa4300b1be59b046492ed3c5fdb59760bc6433f44c08f50de900f9552ec7461"}, - {file = "grpcio_tools-1.48.2-cp39-cp39-win_amd64.whl", hash = "sha256:0fb6c1c1e56eb26b224adc028a4204b6ad0f8b292efa28067dff273bbc8b27c4"}, + {file = "grpcio-tools-1.62.2.tar.gz", hash = "sha256:5fd5e1582b678e6b941ee5f5809340be5e0724691df5299aae8226640f94e18f"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:1679b4903aed2dc5bd8cb22a452225b05dc8470a076f14fd703581efc0740cdb"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:9d41e0e47dd075c075bb8f103422968a65dd0d8dc8613288f573ae91eb1053ba"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:987e774f74296842bbffd55ea8826370f70c499e5b5f71a8cf3103838b6ee9c3"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40cd4eeea4b25bcb6903b82930d579027d034ba944393c4751cdefd9c49e6989"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6746bc823958499a3cf8963cc1de00072962fb5e629f26d658882d3f4c35095"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2ed775e844566ce9ce089be9a81a8b928623b8ee5820f5e4d58c1a9d33dfc5ae"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bdc5dd3f57b5368d5d661d5d3703bcaa38bceca59d25955dff66244dbc987271"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win32.whl", hash = "sha256:3a8d6f07e64c0c7756f4e0c4781d9d5a2b9cc9cbd28f7032a6fb8d4f847d0445"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:e33b59fb3efdddeb97ded988a871710033e8638534c826567738d3edce528752"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:472505d030135d73afe4143b0873efe0dcb385bd6d847553b4f3afe07679af00"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:ec674b4440ef4311ac1245a709e87b36aca493ddc6850eebe0b278d1f2b6e7d1"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:184b4174d4bd82089d706e8223e46c42390a6ebac191073b9772abc77308f9fa"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c195d74fe98541178ece7a50dad2197d43991e0f77372b9a88da438be2486f12"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a34d97c62e61bfe9e6cff0410fe144ac8cca2fc979ad0be46b7edf026339d161"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb8453ae83a1db2452b7fe0f4b78e4a8dd32be0f2b2b73591ae620d4d784d3d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f989e5cebead3ae92c6abf6bf7b19949e1563a776aea896ac5933f143f0c45d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win32.whl", hash = "sha256:c48fabe40b9170f4e3d7dd2c252e4f1ff395dc24e49ac15fc724b1b6f11724da"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:8c616d0ad872e3780693fce6a3ac8ef00fc0963e6d7815ce9dcfae68ba0fc287"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:10cc3321704ecd17c93cf68c99c35467a8a97ffaaed53207e9b2da6ae0308ee1"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:9be84ff6d47fd61462be7523b49d7ba01adf67ce4e1447eae37721ab32464dd8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d82f681c9a9d933a9d8068e8e382977768e7779ddb8870fa0cf918d8250d1532"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04c607029ae3660fb1624ed273811ffe09d57d84287d37e63b5b802a35897329"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72b61332f1b439c14cbd3815174a8f1d35067a02047c32decd406b3a09bb9890"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8214820990d01b52845f9fbcb92d2b7384a0c321b303e3ac614c219dc7d1d3af"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:462e0ab8dd7c7b70bfd6e3195eebc177549ede5cf3189814850c76f9a340d7ce"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win32.whl", hash = "sha256:fa107460c842e4c1a6266150881694fefd4f33baa544ea9489601810c2210ef8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:759c60f24c33a181bbbc1232a6752f9b49fbb1583312a4917e2b389fea0fb0f2"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:45db5da2bcfa88f2b86b57ef35daaae85c60bd6754a051d35d9449c959925b57"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ab84bae88597133f6ea7a2bdc57b2fda98a266fe8d8d4763652cbefd20e73ad7"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7a49bccae1c7d154b78e991885c3111c9ad8c8fa98e91233de425718f47c6139"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7e439476b29d6dac363b321781a113794397afceeb97dad85349db5f1cb5e9a"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ea369c4d1567d1acdf69c8ea74144f4ccad9e545df7f9a4fc64c94fa7684ba3"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f955702dc4b530696375251319d05223b729ed24e8673c2129f7a75d2caefbb"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3708a747aa4b6b505727282ca887041174e146ae030ebcadaf4c1d346858df62"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2ce149ea55eadb486a7fb75a20f63ef3ac065ee6a0240ed25f3549ce7954c653"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:58cbb24b3fa6ae35aa9c210fcea3a51aa5fef0cd25618eb4fd94f746d5a9b703"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:6413581e14a80e0b4532577766cf0586de4dd33766a31b3eb5374a746771c07d"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:47117c8a7e861382470d0e22d336e5a91fdc5f851d1db44fa784b9acea190d87"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f1ba79a253df9e553d20319c615fa2b429684580fa042dba618d7f6649ac7e4"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04a394cf5e51ba9be412eb9f6c482b6270bd81016e033e8eb7d21b8cc28fe8b5"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3c53b221378b035ae2f1881cbc3aca42a6075a8e90e1a342c2f205eb1d1aa6a1"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c384c838b34d1b67068e51b5bbe49caa6aa3633acd158f1ab16b5da8d226bc53"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win32.whl", hash = "sha256:19ea69e41c3565932aa28a202d1875ec56786aea46a2eab54a3b28e8a27f9517"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:1d768a5c07279a4c461ebf52d0cec1c6ca85c6291c71ec2703fe3c3e7e28e8c4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:5b07b5874187e170edfbd7aa2ca3a54ebf3b2952487653e8c0b0d83601c33035"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:d58389fe8be206ddfb4fa703db1e24c956856fcb9a81da62b13577b3a8f7fda7"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:7d8b4e00c3d7237b92260fc18a561cd81f1da82e8be100db1b7d816250defc66"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe08d2038f2b7c53259b5c49e0ad08c8e0ce2b548d8185993e7ef67e8592cca"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19216e1fb26dbe23d12a810517e1b3fbb8d4f98b1a3fbebeec9d93a79f092de4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b8574469ecc4ff41d6bb95f44e0297cdb0d95bade388552a9a444db9cd7485cd"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4f6f32d39283ea834a493fccf0ebe9cfddee7577bdcc27736ad4be1732a36399"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win32.whl", hash = "sha256:76eb459bdf3fb666e01883270beee18f3f11ed44488486b61cd210b4e0e17cc1"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:217c2ee6a7ce519a55958b8622e21804f6fdb774db08c322f4c9536c35fdce7c"}, ] [package.dependencies] -grpcio = ">=1.48.2" -protobuf = ">=3.12.0,<4.0dev" +grpcio = ">=1.62.2" +protobuf = ">=4.21.6,<5.0dev" setuptools = "*" [[package]] name = "hf-transfer" -version = "0.1.8" -description = "Speed up file transfers with the Hugging Face Hub." +version = "0.1.6" +description = "" optional = false python-versions = ">=3.7" files = [ - {file = "hf_transfer-0.1.8-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:70858f9e94286738ed300484a45beb5cfee6a7ddac4c5886f9c6fce7823ac5ab"}, - {file = "hf_transfer-0.1.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38adc73f0a8526319d90f7cc5dc2d5e4bb66f487a513d94b98aa6725be732e4a"}, - {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44d2f0c08198d8d899fe9d66e86aee2dd844bd7ce33888f261373fcec81d2a54"}, - {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1de2a4ef36f9e60b3d3bec00193c0aafd75771709f2ca51b9b162373f5af3d32"}, - {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e319269e3606a5ff2979296841766649ac73598a4a8eee2a968f86c8071fea5a"}, - {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f6026cf3be6a53ea42f92172f60c1c0675baaa9073f865e671b661dde5fd157"}, - {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f865c33ada5bd3650c2b46e59979f2d7755c3f517f8d0facc78576a0c7d26406"}, - {file = "hf_transfer-0.1.8-cp310-none-win32.whl", hash = "sha256:2054730e8d8ed21917c64be7199e06424b2bd08df1c43a72766afaed7992f2d3"}, - {file = "hf_transfer-0.1.8-cp310-none-win_amd64.whl", hash = "sha256:2b4f1a9446ba31170b5b1eca4e916504d18378a6b5fe959896bdac8a736a5ecb"}, - {file = "hf_transfer-0.1.8-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e27c15fcc5869ad7e52bbc0bdec6106b288d1c463f8d2da92f28615a3b181361"}, - {file = "hf_transfer-0.1.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:871a0032d011ebc6409a73a8406b98b84ff2cd3ed7d9e1af8cdf4d660b9fab9b"}, - {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:686fa756e1e0214bb6327d33c66732c52274d94a8460beb50604ad988b391cf6"}, - {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:36a03b1b2911b0cf15b1b9d971a34b32dadcc4f2fd979aaff5979d6ce4017c34"}, - {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:079db90c81f41f4cf3227dfaaa855a9b8e9aef45bc7c2be29ce7232cd83ff881"}, - {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac08a4524127fdd14c234d4bcbe49d1c498acf5335c781714823179bcc8dc039"}, - {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837432e73cb17274a6782b6216e8ce058aa325a475dc44a5a6a753d48b86d18a"}, - {file = "hf_transfer-0.1.8-cp311-none-win32.whl", hash = "sha256:b180f9823dde35aba9bc0f1d0c04ac8a873baebd3732a7ffe4f11940abc7df0d"}, - {file = "hf_transfer-0.1.8-cp311-none-win_amd64.whl", hash = "sha256:37907d2135cebcf8b6d419bb575148d89c224f16b69357f027bd29d0e85c6529"}, - {file = "hf_transfer-0.1.8-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baf948f4f493949309cbe60529620b9b0aef854a22b6e526753364acc57c09b6"}, - {file = "hf_transfer-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bce5c8bdefa478c5d5eaa646cc4ce1df5cfe764d98572ad0c6b8773e98d49f6"}, - {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d6f8a1a86128d651a3799e1267c343d60f81f2c565d7c5416eb8e674e4cf0e"}, - {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f79fd1b0c2ed93efb4c5f684118d7a762ecdd218e170df8208c4e13d3dcd4959"}, - {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:414df35692670683bf5623498ef9d88a8df5d77e9516515da6e2b34d1054c11f"}, - {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c9798d5f951f66b96d40a7a53910260cb5874fda56cf5944dddb7c571f37ec3"}, - {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:060c661691f85a61392e57579c80eb64b5ee277434e81fb582f605c1c8ff05d5"}, - {file = "hf_transfer-0.1.8-cp312-none-win32.whl", hash = "sha256:f7840e32379820c3e1571a480238e05ea043e970c99d2e999578004a2eb17788"}, - {file = "hf_transfer-0.1.8-cp312-none-win_amd64.whl", hash = "sha256:9a3204ec423cc5e659872e8179f8704ad9ce2abb1e6a991f8838aedf1dc07830"}, - {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09949e86ad63ee139e463fd0dfaf401515ae70445854199f61d545514c65f744"}, - {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf1a74552845b93ea972e6e7131ef54e56056aa54137e93a40faf3fbcb2442ff"}, - {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959bcb3afb4ee6f2a07031a947dba98ec0b64c001bc914fbd8fc32e13a287162"}, - {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e01eecdb8162bd61dab9090fbd9f8034dd8b5755ef727a21ca8a057f80cb91ee"}, - {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50650a38e9d31f5ad8f010e4598bf304ecd99c17162e7d93f67e031571b864ee"}, - {file = "hf_transfer-0.1.8-cp37-none-win32.whl", hash = "sha256:e29b9d1d378138f2f4eae0e93ca94af3b5d45f4532eef69f1ab97fe06f9c9d9e"}, - {file = "hf_transfer-0.1.8-cp37-none-win_amd64.whl", hash = "sha256:cfd6cef43ae883103117a371f8ebae4e7f9637bc6fb480f1be5568e2fe22a8a7"}, - {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92a68f7a0043cca8a0de4decc760dca177530944cbab502afac503bd1b2fa01a"}, - {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e3138e408179f80a5480598e32f8e1abb564915cbde4d3bc8da52811c75dc3ea"}, - {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4544d148930ad34442d43b8fa911c8479c04a95b858b1d1f91e0b7da77082fad"}, - {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a851794b9f029965664f8c3002c957fccf21685e9397ceb4f9f19c986dee8ad3"}, - {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:791aaf87c5319ac83edb6ab2994b3db19924c49d6ff667dd3d8a610b455ff70a"}, - {file = "hf_transfer-0.1.8-cp38-none-win32.whl", hash = "sha256:8f71e5d35d3a3160dcca12fdcc8119033aeacaa6a32838a7ad9f9cb1008bbe58"}, - {file = "hf_transfer-0.1.8-cp38-none-win_amd64.whl", hash = "sha256:543287b4ceb1e25501580b99690f7f0df9d3631d29306f37cbd97e918c732944"}, - {file = "hf_transfer-0.1.8-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:7ce02a18bd0bb2343e707ac85b68c946bc37623ee24150c69158f6b2b2c7a98f"}, - {file = "hf_transfer-0.1.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64d7f8dbd64ba183ed1df75d47c84e075ff666ceaa335bff1de16b09eaac5b80"}, - {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e7858694e11419ae27e542fb8fc0d0e54d46ff7768fe73bc359d70b8f5aa578"}, - {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed116cd9d1edfa32c0136d7cb8e5f1afd2b32df43c49085d428f108fc8e1c8f"}, - {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e385d0da9c6b3472ab29285d2d46c9f9903205b8d108f88a82f3f85aafae0ab"}, - {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98f75fa4b86ef15433cd907807ac77d1fb39d7e7b790bfd39c7ae9c385bf0200"}, - {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a63ad947d2901425ac0a3ed70c3696dfde27fadb0482ed763bdd5cc946b278"}, - {file = "hf_transfer-0.1.8-cp39-none-win32.whl", hash = "sha256:3e74096915813ae842ea6a5bdf10c0fef960aa51a35a560955b3e61cdfe3db57"}, - {file = "hf_transfer-0.1.8-cp39-none-win_amd64.whl", hash = "sha256:05ea16307bf4a5eb097cbc6e5057e4eb5e080a138af23ef639fd38857723c288"}, - {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:928ff036c3e98e10dcfbdb4fcdfc4592d37a5cc8e365a7ba8dfd4337e849d675"}, - {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d49ba3ce67035f460ae1924fe2feafec155cb535eec7f31ed5109c19064cd294"}, - {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b01f5872c62cfee3ec9ca5c738818296f69f8adf84b4d8d15f2a5601d9dda339"}, - {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:659d4212d50847a5165666bf43d67727679b4f694ef9c413613cc27093136527"}, - {file = "hf_transfer-0.1.8.tar.gz", hash = "sha256:26d229468152e7a3ec12664cac86b8c2800695fd85f9c9a96677a775cc04f0b3"}, + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, + {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, + {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, + {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, + {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, + {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, + {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, + {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, + {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, + {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, + {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, + {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, + {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, + {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, ] [[package]] name = "huggingface-hub" -version = "0.24.6" +version = "0.23.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.24.6-py3-none-any.whl", hash = "sha256:a990f3232aa985fe749bc9474060cbad75e8b2f115f6665a9fda5b9c97818970"}, - {file = "huggingface_hub-0.24.6.tar.gz", hash = "sha256:cc2579e761d070713eaa9c323e3debe39d5b464ae3a7261c39a9195b27bb8000"}, + {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, + {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, ] [package.dependencies] @@ -926,63 +893,30 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] hf-transfer = ["hf-transfer (>=0.1.4)"] inference = ["aiohttp", "minijinja (>=1.0)"] -quality = ["mypy (==1.5.1)", "ruff (>=0.5.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] tensorflow-testing = ["keras (<3.0)", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] -torch = ["safetensors[torch]", "torch"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] -[[package]] -name = "humanfriendly" -version = "10.0" -description = "Human friendly output for text interfaces using Python" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -files = [ - {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"}, - {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"}, -] - -[package.dependencies] -pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} - [[package]] name = "idna" -version = "3.8" +version = "3.7" description = "Internationalized Domain Names in Applications (IDNA)" optional = false -python-versions = ">=3.6" +python-versions = ">=3.5" files = [ - {file = "idna-3.8-py3-none-any.whl", hash = "sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac"}, - {file = "idna-3.8.tar.gz", hash = "sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603"}, + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] -[[package]] -name = "importlib-metadata" -version = "8.4.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"}, - {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"}, -] - -[package.dependencies] -zipp = ">=0.5" - -[package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] - [[package]] name = "iniconfig" version = "2.0.0" @@ -994,6 +928,20 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "intel-openmp" +version = "2021.4.0" +description = "Intel OpenMP* Runtime Library" +optional = true +python-versions = "*" +files = [ + {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, +] + [[package]] name = "interegular" version = "0.3.3" @@ -1009,7 +957,7 @@ files = [ name = "jinja2" version = "3.1.4" description = "A very fast and expressive template engine." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, @@ -1026,7 +974,7 @@ i18n = ["Babel (>=2.7)"] name = "joblib" version = "1.4.2" description = "Lightweight pipelining with Python functions" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, @@ -1035,13 +983,13 @@ files = [ [[package]] name = "jsonschema" -version = "4.23.0" +version = "4.22.0" description = "An implementation of JSON Schema validation for Python" optional = true python-versions = ">=3.8" files = [ - {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, - {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, + {file = "jsonschema-4.22.0-py3-none-any.whl", hash = "sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802"}, + {file = "jsonschema-4.22.0.tar.gz", hash = "sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7"}, ] [package.dependencies] @@ -1052,7 +1000,7 @@ rpds-py = ">=0.7.1" [package.extras] format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] -format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] [[package]] name = "jsonschema-specifications" @@ -1070,13 +1018,13 @@ referencing = ">=0.31.0" [[package]] name = "lark" -version = "1.2.2" +version = "1.1.9" description = "a modern parsing library" optional = true -python-versions = ">=3.8" +python-versions = ">=3.6" files = [ - {file = "lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c"}, - {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"}, + {file = "lark-1.1.9-py3-none-any.whl", hash = "sha256:a0dd3a87289f8ccbb325901e4222e723e7d745dbfc1803eaf5f3d2ace19cf2db"}, + {file = "lark-1.1.9.tar.gz", hash = "sha256:15fa5236490824c2c4aba0e22d2d6d823575dcaf4cdd1848e34b6ad836240fba"}, ] [package.extras] @@ -1087,32 +1035,32 @@ regex = ["regex"] [[package]] name = "llvmlite" -version = "0.43.0" +version = "0.42.0" description = "lightweight wrapper around basic LLVM functionality" optional = true python-versions = ">=3.9" files = [ - {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, - {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, - {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"}, - {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"}, - {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"}, - {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"}, - {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"}, - {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"}, - {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"}, - {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"}, - {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"}, - {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"}, - {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"}, - {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"}, - {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"}, - {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"}, - {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"}, - {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"}, - {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"}, - {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"}, - {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, + {file = "llvmlite-0.42.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3366938e1bf63d26c34fbfb4c8e8d2ded57d11e0567d5bb243d89aab1eb56098"}, + {file = "llvmlite-0.42.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c35da49666a21185d21b551fc3caf46a935d54d66969d32d72af109b5e7d2b6f"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70f44ccc3c6220bd23e0ba698a63ec2a7d3205da0d848804807f37fc243e3f77"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d"}, + {file = "llvmlite-0.42.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d90edf400b4ceb3a0e776b6c6e4656d05c7187c439587e06f86afceb66d2be5"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ae511caed28beaf1252dbaf5f40e663f533b79ceb408c874c01754cafabb9cbf"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81e674c2fe85576e6c4474e8c7e7aba7901ac0196e864fe7985492b737dbab65"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb3975787f13eb97629052edb5017f6c170eebc1c14a0433e8089e5db43bcce6"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5bece0cdf77f22379f19b1959ccd7aee518afa4afbd3656c6365865f84903f9"}, + {file = "llvmlite-0.42.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e0c4c11c8c2aa9b0701f91b799cb9134a6a6de51444eff5a9087fc7c1384275"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:08fa9ab02b0d0179c688a4216b8939138266519aaa0aa94f1195a8542faedb56"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b2fce7d355068494d1e42202c7aff25d50c462584233013eb4470c33b995e3ee"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebe66a86dc44634b59a3bc860c7b20d26d9aaffcd30364ebe8ba79161a9121f4"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47494552559e00d81bfb836cf1c4d5a5062e54102cc5767d5aa1e77ccd2505c"}, + {file = "llvmlite-0.42.0-cp312-cp312-win_amd64.whl", hash = "sha256:05cb7e9b6ce69165ce4d1b994fbdedca0c62492e537b0cc86141b6e2c78d5888"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdd3888544538a94d7ec99e7c62a0cdd8833609c85f0c23fcb6c5c591aec60ad"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0936c2067a67fb8816c908d5457d63eba3e2b17e515c5fe00e5ee2bace06040"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a78ab89f1924fc11482209f6799a7a3fc74ddc80425a7a3e0e8174af0e9e2301"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7599b65c7af7abbc978dbf345712c60fd596aa5670496561cc10e8a71cebfb2"}, + {file = "llvmlite-0.42.0-cp39-cp39-win_amd64.whl", hash = "sha256:43d65cc4e206c2e902c1004dd5418417c4efa6c1d04df05c6c5675a27e8ca90e"}, + {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, ] [[package]] @@ -1137,7 +1085,7 @@ dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils name = "markupsafe" version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, @@ -1202,11 +1150,29 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "mkl" +version = "2021.4.0" +description = "Intel® oneAPI Math Kernel Library" +optional = true +python-versions = "*" +files = [ + {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, + {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, + {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, +] + +[package.dependencies] +intel-openmp = "==2021.*" +tbb = "==2021.*" + [[package]] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" -optional = false +optional = true python-versions = "*" files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, @@ -1223,7 +1189,7 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.5" description = "multidict implementation" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, @@ -1320,27 +1286,31 @@ files = [ [[package]] name = "multiprocess" -version = "0.70.16" +version = "0.70.15" description = "better multiprocessing and multithreading in Python" -optional = false -python-versions = ">=3.8" +optional = true +python-versions = ">=3.7" files = [ - {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, - {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, - {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, - {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, - {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, - {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, - {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, - {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, - {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, - {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, - {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, - {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, + {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, + {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, + {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, + {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"}, + {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"}, + {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"}, + {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"}, ] [package.dependencies] -dill = ">=0.3.8" +dill = ">=0.3.7" [[package]] name = "nest-asyncio" @@ -1357,7 +1327,7 @@ files = [ name = "networkx" version = "3.2.1" description = "Python package for creating and manipulating graphs and networks" -optional = false +optional = true python-versions = ">=3.9" files = [ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, @@ -1373,37 +1343,37 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "numba" -version = "0.60.0" +version = "0.59.1" description = "compiling Python code using LLVM" optional = true python-versions = ">=3.9" files = [ - {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, - {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, - {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"}, - {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"}, - {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"}, - {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"}, - {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"}, - {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"}, - {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"}, - {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"}, - {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"}, - {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"}, - {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"}, - {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"}, - {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"}, - {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"}, - {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"}, - {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"}, - {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"}, - {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"}, - {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"}, + {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, + {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, + {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, + {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, + {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, + {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, + {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, + {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, + {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, + {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, + {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, + {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, + {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, ] [package.dependencies] -llvmlite = "==0.43.*" -numpy = ">=1.22,<2.1" +llvmlite = "==0.42.*" +numpy = ">=1.22,<1.27" [[package]] name = "numpy" @@ -1454,7 +1424,7 @@ files = [ name = "nvidia-cublas-cu12" version = "12.1.3.1" description = "CUBLAS native runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, @@ -1465,7 +1435,7 @@ files = [ name = "nvidia-cuda-cupti-cu12" version = "12.1.105" description = "CUDA profiling tools runtime libs." -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, @@ -1476,7 +1446,7 @@ files = [ name = "nvidia-cuda-nvrtc-cu12" version = "12.1.105" description = "NVRTC native runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, @@ -1487,7 +1457,7 @@ files = [ name = "nvidia-cuda-runtime-cu12" version = "12.1.105" description = "CUDA Runtime native Libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, @@ -1496,13 +1466,12 @@ files = [ [[package]] name = "nvidia-cudnn-cu12" -version = "9.1.0.70" +version = "8.9.2.26" description = "cuDNN runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ - {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, - {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] [package.dependencies] @@ -1512,7 +1481,7 @@ nvidia-cublas-cu12 = "*" name = "nvidia-cufft-cu12" version = "11.0.2.54" description = "CUFFT native runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, @@ -1523,7 +1492,7 @@ files = [ name = "nvidia-curand-cu12" version = "10.3.2.106" description = "CURAND native runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, @@ -1534,7 +1503,7 @@ files = [ name = "nvidia-cusolver-cu12" version = "11.4.5.107" description = "CUDA solver native runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, @@ -1550,7 +1519,7 @@ nvidia-nvjitlink-cu12 = "*" name = "nvidia-cusparse-cu12" version = "12.1.0.106" description = "CUSPARSE native runtime libraries" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, @@ -1564,7 +1533,7 @@ nvidia-nvjitlink-cu12 = "*" name = "nvidia-nccl-cu12" version = "2.20.5" description = "NVIDIA Collective Communication Library (NCCL) Runtime" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, @@ -1573,21 +1542,20 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.6.20" +version = "12.5.40" description = "Nvidia JIT LTO Library" -optional = false +optional = true python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"}, - {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"}, - {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"}, + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] [[package]] name = "nvidia-nvtx-cu12" version = "12.1.105" description = "NVIDIA Tools Extension" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, @@ -1748,75 +1716,6 @@ files = [ {file = "opentelemetry_semantic_conventions-0.36b0.tar.gz", hash = "sha256:829dc221795467d98b773c04096e29be038d77526dc8d6ac76f546fb6279bf01"}, ] -[[package]] -name = "optimum" -version = "1.21.4" -description = "Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality." -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "optimum-1.21.4-py3-none-any.whl", hash = "sha256:d18f426bcfc7bed9754af0305e1a223100195bab804d47ea5e4f8b0562f23d46"}, - {file = "optimum-1.21.4.tar.gz", hash = "sha256:47e751f1068c81fec7846d10c67bf39c18e41fda6cd15566dfc85c3ea82262f7"}, -] - -[package.dependencies] -coloredlogs = "*" -datasets = "*" -huggingface-hub = ">=0.8.0" -numpy = "<2.0" -packaging = "*" -sympy = "*" -torch = ">=1.11" -transformers = {version = ">=4.29.0,<4.44.0", extras = ["sentencepiece"]} - -[package.extras] -amd = ["optimum-amd"] -benchmark = ["evaluate (>=0.2.0)", "optuna", "scikit-learn", "seqeval", "torchvision", "tqdm"] -dev = ["Pillow", "accelerate", "black (>=23.1,<24.0)", "diffusers (>=0.17.0)", "einops", "invisible-watermark", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "ruff (==0.1.5)", "sacremoses", "scikit-learn", "timm", "torchaudio", "torchvision"] -diffusers = ["diffusers"] -doc-build = ["accelerate"] -exporters = ["onnx", "onnxruntime", "timm"] -exporters-gpu = ["onnx", "onnxruntime-gpu", "timm"] -exporters-tf = ["h5py", "numpy (<1.24.0)", "onnx", "onnxruntime", "tensorflow (>=2.4,<=2.12.1)", "tf2onnx", "timm", "transformers[sentencepiece] (>=4.26.0,<4.38.0)"] -furiosa = ["optimum-furiosa"] -graphcore = ["optimum-graphcore"] -habana = ["optimum-habana", "transformers (>=4.43.0,<4.44.0)"] -intel = ["optimum-intel (>=1.18.0)"] -ipex = ["optimum-intel[ipex] (>=1.18.0)"] -neural-compressor = ["optimum-intel[neural-compressor] (>=1.18.0)"] -neuron = ["optimum-neuron[neuron] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"] -neuronx = ["optimum-neuron[neuronx] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"] -nncf = ["optimum-intel[nncf] (>=1.18.0)"] -onnxruntime = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime (>=1.11.0)", "protobuf (>=3.20.1)"] -onnxruntime-gpu = ["accelerate", "datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime-gpu (>=1.11.0)", "protobuf (>=3.20.1)"] -openvino = ["optimum-intel[openvino] (>=1.18.0)"] -quality = ["black (>=23.1,<24.0)", "ruff (==0.1.5)"] -tests = ["Pillow", "accelerate", "diffusers (>=0.17.0)", "einops", "invisible-watermark", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "sacremoses", "scikit-learn", "timm", "torchaudio", "torchvision"] - -[[package]] -name = "optimum-habana" -version = "1.13.2" -description = "Optimum Habana is the interface between the Hugging Face Transformers and Diffusers libraries and Habana's Gaudi processor (HPU). It provides a set of tools enabling easy model loading, training and inference on single- and multi-HPU settings for different downstream tasks." -optional = false -python-versions = "*" -files = [ - {file = "optimum_habana-1.13.2-py3-none-any.whl", hash = "sha256:fbec406db58e6e24b9d4cfdbc72ac88423eeef3edaf81645dbd502141a29d052"}, - {file = "optimum_habana-1.13.2.tar.gz", hash = "sha256:6e88c93860e33c8186f6dbd55a61ed287ed3cba241f289fcf4a4001d6e2a41f0"}, -] - -[package.dependencies] -accelerate = ">=0.33.0,<0.34.0" -diffusers = "0.29.2" -huggingface-hub = ">=0.23.2" -optimum = "*" -sentence-transformers = {version = "3.0.1", extras = ["train"]} -torch = "*" -transformers = ">=4.43.0,<4.44.0" - -[package.extras] -quality = ["hf-doc-builder", "ruff"] -tests = ["GitPython", "datasets", "optuna", "parameterized", "peft", "psutil", "pytest (<8.0.0)", "safetensors", "scipy", "sentencepiece", "timm", "torchsde"] - [[package]] name = "outlines" version = "0.0.36" @@ -1852,20 +1751,20 @@ test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets" [[package]] name = "packaging" -version = "24.1" +version = "24.0" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, - {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] [[package]] name = "pandas" version = "2.2.2" description = "Powerful data structures for data analysis, time series, and statistics" -optional = false +optional = true python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, @@ -1938,7 +1837,7 @@ xml = ["lxml (>=4.9.2)"] name = "peft" version = "0.10.0" description = "Parameter-Efficient Fine-Tuning (PEFT)" -optional = false +optional = true python-versions = ">=3.8.0" files = [ {file = "peft-0.10.0-py3-none-any.whl", hash = "sha256:d5249c97e818d3e31f92553c73c2953acd0ec12649b8b749afff7152cbc86cbb"}, @@ -1965,95 +1864,84 @@ test = ["black", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameter [[package]] name = "pillow" -version = "10.4.0" +version = "10.3.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" files = [ - {file = "pillow-10.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e"}, - {file = "pillow-10.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc"}, - {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e"}, - {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46"}, - {file = "pillow-10.4.0-cp310-cp310-win32.whl", hash = "sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984"}, - {file = "pillow-10.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141"}, - {file = "pillow-10.4.0-cp310-cp310-win_arm64.whl", hash = "sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1"}, - {file = "pillow-10.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c"}, - {file = "pillow-10.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319"}, - {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d"}, - {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696"}, - {file = "pillow-10.4.0-cp311-cp311-win32.whl", hash = "sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496"}, - {file = "pillow-10.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91"}, - {file = "pillow-10.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22"}, - {file = "pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94"}, - {file = "pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a"}, - {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b"}, - {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9"}, - {file = "pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42"}, - {file = "pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a"}, - {file = "pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9"}, - {file = "pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3"}, - {file = "pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc"}, - {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a"}, - {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309"}, - {file = "pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060"}, - {file = "pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea"}, - {file = "pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d"}, - {file = "pillow-10.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736"}, - {file = "pillow-10.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd"}, - {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84"}, - {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0"}, - {file = "pillow-10.4.0-cp38-cp38-win32.whl", hash = "sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e"}, - {file = "pillow-10.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab"}, - {file = "pillow-10.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d"}, - {file = "pillow-10.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c"}, - {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1"}, - {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df"}, - {file = "pillow-10.4.0-cp39-cp39-win32.whl", hash = "sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef"}, - {file = "pillow-10.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5"}, - {file = "pillow-10.4.0-cp39-cp39-win_arm64.whl", hash = "sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3"}, - {file = "pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06"}, + {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"}, + {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"}, + {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"}, + {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"}, + {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"}, + {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"}, + {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"}, + {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"}, + {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"}, + {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"}, + {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"}, + {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"}, + {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"}, + {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"}, + {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"}, + {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"}, + {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"}, + {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"}, + {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"}, + {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"}, + {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"}, + {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"}, + {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"}, + {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"}, + {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"}, + {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"}, + {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"}, + {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"}, + {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"}, + {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"}, + {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"}, + {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"}, + {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"}, + {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"}, + {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"}, ] [package.extras] -docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] @@ -2091,59 +1979,47 @@ twisted = ["twisted"] [[package]] name = "protobuf" -version = "3.20.3" -description = "Protocol Buffers" +version = "4.25.3" +description = "" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99"}, - {file = "protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e"}, - {file = "protobuf-3.20.3-cp310-cp310-win32.whl", hash = "sha256:28545383d61f55b57cf4df63eebd9827754fd2dc25f80c5253f9184235db242c"}, - {file = "protobuf-3.20.3-cp310-cp310-win_amd64.whl", hash = "sha256:67a3598f0a2dcbc58d02dd1928544e7d88f764b47d4a286202913f0b2801c2e7"}, - {file = "protobuf-3.20.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:899dc660cd599d7352d6f10d83c95df430a38b410c1b66b407a6b29265d66469"}, - {file = "protobuf-3.20.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e64857f395505ebf3d2569935506ae0dfc4a15cb80dc25261176c784662cdcc4"}, - {file = "protobuf-3.20.3-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:d9e4432ff660d67d775c66ac42a67cf2453c27cb4d738fc22cb53b5d84c135d4"}, - {file = "protobuf-3.20.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:74480f79a023f90dc6e18febbf7b8bac7508420f2006fabd512013c0c238f454"}, - {file = "protobuf-3.20.3-cp37-cp37m-win32.whl", hash = "sha256:b6cc7ba72a8850621bfec987cb72623e703b7fe2b9127a161ce61e61558ad905"}, - {file = "protobuf-3.20.3-cp37-cp37m-win_amd64.whl", hash = "sha256:8c0c984a1b8fef4086329ff8dd19ac77576b384079247c770f29cc8ce3afa06c"}, - {file = "protobuf-3.20.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:de78575669dddf6099a8a0f46a27e82a1783c557ccc38ee620ed8cc96d3be7d7"}, - {file = "protobuf-3.20.3-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:f4c42102bc82a51108e449cbb32b19b180022941c727bac0cfd50170341f16ee"}, - {file = "protobuf-3.20.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:44246bab5dd4b7fbd3c0c80b6f16686808fab0e4aca819ade6e8d294a29c7050"}, - {file = "protobuf-3.20.3-cp38-cp38-win32.whl", hash = "sha256:c02ce36ec760252242a33967d51c289fd0e1c0e6e5cc9397e2279177716add86"}, - {file = "protobuf-3.20.3-cp38-cp38-win_amd64.whl", hash = "sha256:447d43819997825d4e71bf5769d869b968ce96848b6479397e29fc24c4a5dfe9"}, - {file = "protobuf-3.20.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:398a9e0c3eaceb34ec1aee71894ca3299605fa8e761544934378bbc6c97de23b"}, - {file = "protobuf-3.20.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:bf01b5720be110540be4286e791db73f84a2b721072a3711efff6c324cdf074b"}, - {file = "protobuf-3.20.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:daa564862dd0d39c00f8086f88700fdbe8bc717e993a21e90711acfed02f2402"}, - {file = "protobuf-3.20.3-cp39-cp39-win32.whl", hash = "sha256:819559cafa1a373b7096a482b504ae8a857c89593cf3a25af743ac9ecbd23480"}, - {file = "protobuf-3.20.3-cp39-cp39-win_amd64.whl", hash = "sha256:03038ac1cfbc41aa21f6afcbcd357281d7521b4157926f30ebecc8d4ea59dcb7"}, - {file = "protobuf-3.20.3-py2.py3-none-any.whl", hash = "sha256:a7ca6d488aa8ff7f329d4c545b2dbad8ac31464f1d8b1c87ad1346717731e4db"}, - {file = "protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2"}, + {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"}, + {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"}, + {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"}, + {file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"}, + {file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"}, + {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"}, + {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"}, + {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"}, + {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, ] [[package]] name = "psutil" -version = "6.0.0" +version = "5.9.8" description = "Cross-platform lib for process and system monitoring in Python." -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ - {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, - {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, - {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, - {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, - {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, - {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, - {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, - {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, - {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, - {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, - {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, - {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, - {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, - {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, + {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, + {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, + {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, + {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, + {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, + {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, + {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, + {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, + {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, + {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, + {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, + {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, + {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, + {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, + {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, + {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, ] [package.extras] @@ -2162,186 +2038,162 @@ files = [ [[package]] name = "pyarrow" -version = "17.0.0" +version = "16.1.0" description = "Python library for Apache Arrow" -optional = false +optional = true python-versions = ">=3.8" files = [ - {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, - {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, - {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, - {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, - {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, - {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, - {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, - {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, - {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, - {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, - {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, - {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, - {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, - {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, + {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, + {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, + {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, + {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, + {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, + {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, ] [package.dependencies] numpy = ">=1.16.6" -[package.extras] -test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] - [[package]] name = "pydantic" -version = "2.8.2" +version = "2.7.1" description = "Data validation using Python type hints" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, - {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, + {file = "pydantic-2.7.1-py3-none-any.whl", hash = "sha256:e029badca45266732a9a79898a15ae2e8b14840b1eabbb25844be28f0b33f3d5"}, + {file = "pydantic-2.7.1.tar.gz", hash = "sha256:e9dbb5eada8abe4d9ae5f46b9939aead650cd2b68f249bb3a8139dbe125803cc"}, ] [package.dependencies] annotated-types = ">=0.4.0" -pydantic-core = "2.20.1" -typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} +pydantic-core = "2.18.2" +typing-extensions = ">=4.6.1" [package.extras] email = ["email-validator (>=2.0.0)"] [[package]] name = "pydantic-core" -version = "2.20.1" +version = "2.18.2" description = "Core functionality for Pydantic validation and serialization" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, - {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, - {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, - {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, - {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, - {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, - {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, - {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, - {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, - {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, - {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, - {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, - {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, - {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, - {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, - {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, - {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, - {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, - {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, - {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, - {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, - {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, - {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, - {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, - {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, - {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, - {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, - {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, - {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, - {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, - {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, - {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, - {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, - {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, - {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, + {file = "pydantic_core-2.18.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9e08e867b306f525802df7cd16c44ff5ebbe747ff0ca6cf3fde7f36c05a59a81"}, + {file = "pydantic_core-2.18.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f0a21cbaa69900cbe1a2e7cad2aa74ac3cf21b10c3efb0fa0b80305274c0e8a2"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0680b1f1f11fda801397de52c36ce38ef1c1dc841a0927a94f226dea29c3ae3d"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:95b9d5e72481d3780ba3442eac863eae92ae43a5f3adb5b4d0a1de89d42bb250"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fcf5cd9c4b655ad666ca332b9a081112cd7a58a8b5a6ca7a3104bc950f2038"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b5155ff768083cb1d62f3e143b49a8a3432e6789a3abee8acd005c3c7af1c74"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:553ef617b6836fc7e4df130bb851e32fe357ce36336d897fd6646d6058d980af"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89ed9eb7d616ef5714e5590e6cf7f23b02d0d539767d33561e3675d6f9e3857"}, + {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:75f7e9488238e920ab6204399ded280dc4c307d034f3924cd7f90a38b1829563"}, + {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ef26c9e94a8c04a1b2924149a9cb081836913818e55681722d7f29af88fe7b38"}, + {file = "pydantic_core-2.18.2-cp310-none-win32.whl", hash = "sha256:182245ff6b0039e82b6bb585ed55a64d7c81c560715d1bad0cbad6dfa07b4027"}, + {file = "pydantic_core-2.18.2-cp310-none-win_amd64.whl", hash = "sha256:e23ec367a948b6d812301afc1b13f8094ab7b2c280af66ef450efc357d2ae543"}, + {file = "pydantic_core-2.18.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:219da3f096d50a157f33645a1cf31c0ad1fe829a92181dd1311022f986e5fbe3"}, + {file = "pydantic_core-2.18.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cc1cfd88a64e012b74e94cd00bbe0f9c6df57049c97f02bb07d39e9c852e19a4"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b7133a6e6aeb8df37d6f413f7705a37ab4031597f64ab56384c94d98fa0e90"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:224c421235f6102e8737032483f43c1a8cfb1d2f45740c44166219599358c2cd"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b14d82cdb934e99dda6d9d60dc84a24379820176cc4a0d123f88df319ae9c150"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2728b01246a3bba6de144f9e3115b532ee44bd6cf39795194fb75491824a1413"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:470b94480bb5ee929f5acba6995251ada5e059a5ef3e0dfc63cca287283ebfa6"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:997abc4df705d1295a42f95b4eec4950a37ad8ae46d913caeee117b6b198811c"}, + {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:75250dbc5290e3f1a0f4618db35e51a165186f9034eff158f3d490b3fed9f8a0"}, + {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4456f2dca97c425231d7315737d45239b2b51a50dc2b6f0c2bb181fce6207664"}, + {file = "pydantic_core-2.18.2-cp311-none-win32.whl", hash = "sha256:269322dcc3d8bdb69f054681edff86276b2ff972447863cf34c8b860f5188e2e"}, + {file = "pydantic_core-2.18.2-cp311-none-win_amd64.whl", hash = "sha256:800d60565aec896f25bc3cfa56d2277d52d5182af08162f7954f938c06dc4ee3"}, + {file = "pydantic_core-2.18.2-cp311-none-win_arm64.whl", hash = "sha256:1404c69d6a676245199767ba4f633cce5f4ad4181f9d0ccb0577e1f66cf4c46d"}, + {file = "pydantic_core-2.18.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:fb2bd7be70c0fe4dfd32c951bc813d9fe6ebcbfdd15a07527796c8204bd36242"}, + {file = "pydantic_core-2.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6132dd3bd52838acddca05a72aafb6eab6536aa145e923bb50f45e78b7251043"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d904828195733c183d20a54230c0df0eb46ec746ea1a666730787353e87182"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9bd70772c720142be1020eac55f8143a34ec9f82d75a8e7a07852023e46617f"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b8ed04b3582771764538f7ee7001b02e1170223cf9b75dff0bc698fadb00cf3"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6dac87ddb34aaec85f873d737e9d06a3555a1cc1a8e0c44b7f8d5daeb89d86f"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca4ae5a27ad7a4ee5170aebce1574b375de390bc01284f87b18d43a3984df72"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:886eec03591b7cf058467a70a87733b35f44707bd86cf64a615584fd72488b7c"}, + {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca7b0c1f1c983e064caa85f3792dd2fe3526b3505378874afa84baf662e12241"}, + {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b4356d3538c3649337df4074e81b85f0616b79731fe22dd11b99499b2ebbdf3"}, + {file = "pydantic_core-2.18.2-cp312-none-win32.whl", hash = "sha256:8b172601454f2d7701121bbec3425dd71efcb787a027edf49724c9cefc14c038"}, + {file = "pydantic_core-2.18.2-cp312-none-win_amd64.whl", hash = "sha256:b1bd7e47b1558ea872bd16c8502c414f9e90dcf12f1395129d7bb42a09a95438"}, + {file = "pydantic_core-2.18.2-cp312-none-win_arm64.whl", hash = "sha256:98758d627ff397e752bc339272c14c98199c613f922d4a384ddc07526c86a2ec"}, + {file = "pydantic_core-2.18.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9fdad8e35f278b2c3eb77cbdc5c0a49dada440657bf738d6905ce106dc1de439"}, + {file = "pydantic_core-2.18.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d90c3265ae107f91a4f279f4d6f6f1d4907ac76c6868b27dc7fb33688cfb347"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:390193c770399861d8df9670fb0d1874f330c79caaca4642332df7c682bf6b91"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:82d5d4d78e4448683cb467897fe24e2b74bb7b973a541ea1dcfec1d3cbce39fb"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4774f3184d2ef3e14e8693194f661dea5a4d6ca4e3dc8e39786d33a94865cefd"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4d938ec0adf5167cb335acb25a4ee69a8107e4984f8fbd2e897021d9e4ca21b"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e8b1be28239fc64a88a8189d1df7fad8be8c1ae47fcc33e43d4be15f99cc70"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:868649da93e5a3d5eacc2b5b3b9235c98ccdbfd443832f31e075f54419e1b96b"}, + {file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:78363590ef93d5d226ba21a90a03ea89a20738ee5b7da83d771d283fd8a56761"}, + {file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:852e966fbd035a6468fc0a3496589b45e2208ec7ca95c26470a54daed82a0788"}, + {file = "pydantic_core-2.18.2-cp38-none-win32.whl", hash = "sha256:6a46e22a707e7ad4484ac9ee9f290f9d501df45954184e23fc29408dfad61350"}, + {file = "pydantic_core-2.18.2-cp38-none-win_amd64.whl", hash = "sha256:d91cb5ea8b11607cc757675051f61b3d93f15eca3cefb3e6c704a5d6e8440f4e"}, + {file = "pydantic_core-2.18.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ae0a8a797a5e56c053610fa7be147993fe50960fa43609ff2a9552b0e07013e8"}, + {file = "pydantic_core-2.18.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:042473b6280246b1dbf530559246f6842b56119c2926d1e52b631bdc46075f2a"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a388a77e629b9ec814c1b1e6b3b595fe521d2cdc625fcca26fbc2d44c816804"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25add29b8f3b233ae90ccef2d902d0ae0432eb0d45370fe315d1a5cf231004b"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f459a5ce8434614dfd39bbebf1041952ae01da6bed9855008cb33b875cb024c0"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eff2de745698eb46eeb51193a9f41d67d834d50e424aef27df2fcdee1b153845"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8309f67285bdfe65c372ea3722b7a5642680f3dba538566340a9d36e920b5f0"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f93a8a2e3938ff656a7c1bc57193b1319960ac015b6e87d76c76bf14fe0244b4"}, + {file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:22057013c8c1e272eb8d0eebc796701167d8377441ec894a8fed1af64a0bf399"}, + {file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfeecd1ac6cc1fb2692c3d5110781c965aabd4ec5d32799773ca7b1456ac636b"}, + {file = "pydantic_core-2.18.2-cp39-none-win32.whl", hash = "sha256:0d69b4c2f6bb3e130dba60d34c0845ba31b69babdd3f78f7c0c8fae5021a253e"}, + {file = "pydantic_core-2.18.2-cp39-none-win_amd64.whl", hash = "sha256:d9319e499827271b09b4e411905b24a426b8fb69464dfa1696258f53a3334641"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a1874c6dd4113308bd0eb568418e6114b252afe44319ead2b4081e9b9521fe75"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:ccdd111c03bfd3666bd2472b674c6899550e09e9f298954cfc896ab92b5b0e6d"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e18609ceaa6eed63753037fc06ebb16041d17d28199ae5aba0052c51449650a9"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5c584d357c4e2baf0ff7baf44f4994be121e16a2c88918a5817331fc7599d7"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43f0f463cf89ace478de71a318b1b4f05ebc456a9b9300d027b4b57c1a2064fb"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e1b395e58b10b73b07b7cf740d728dd4ff9365ac46c18751bf8b3d8cca8f625a"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0098300eebb1c837271d3d1a2cd2911e7c11b396eac9661655ee524a7f10587b"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:36789b70d613fbac0a25bb07ab3d9dba4d2e38af609c020cf4d888d165ee0bf3"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f9a801e7c8f1ef8718da265bba008fa121243dfe37c1cea17840b0944dfd72c"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3a6515ebc6e69d85502b4951d89131ca4e036078ea35533bb76327f8424531ce"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20aca1e2298c56ececfd8ed159ae4dde2df0781988c97ef77d5c16ff4bd5b400"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:223ee893d77a310a0391dca6df00f70bbc2f36a71a895cecd9a0e762dc37b349"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2334ce8c673ee93a1d6a65bd90327588387ba073c17e61bf19b4fd97d688d63c"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cbca948f2d14b09d20268cda7b0367723d79063f26c4ffc523af9042cad95592"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b3ef08e20ec49e02d5c6717a91bb5af9b20f1805583cb0adfe9ba2c6b505b5ae"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6fdc8627910eed0c01aed6a390a252fe3ea6d472ee70fdde56273f198938374"}, + {file = "pydantic_core-2.18.2.tar.gz", hash = "sha256:2e29d20810dfc3043ee13ac7d9e25105799817683348823f305ab3f349b9386e"}, ] [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" -[[package]] -name = "pyreadline3" -version = "3.4.1" -description = "A python implementation of GNU readline." -optional = false -python-versions = "*" -files = [ - {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, - {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, -] - [[package]] name = "pytest" version = "7.4.4" @@ -2368,7 +2220,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "python-dateutil" version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" -optional = false +optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, @@ -2382,7 +2234,7 @@ six = ">=1.5" name = "pytz" version = "2024.1" description = "World timezone definitions, modern and historical" -optional = false +optional = true python-versions = "*" files = [ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, @@ -2391,64 +2243,62 @@ files = [ [[package]] name = "pyyaml" -version = "6.0.2" +version = "6.0.1" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.6" files = [ - {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, - {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, - {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, - {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, - {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, - {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, - {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, - {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, - {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, - {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, - {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, - {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, - {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, - {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, - {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, - {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, - {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, - {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, - {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, - {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, - {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, - {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, - {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, - {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, - {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, - {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, - {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, - {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, - {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, - {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, - {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, - {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, - {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, - {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, - {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, - {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, - {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, - {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, - {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, - {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, - {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, - {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, - {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, - {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, - {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, - {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, - {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, - {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, - {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, - {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, - {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, - {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, - {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] [[package]] @@ -2468,101 +2318,101 @@ rpds-py = ">=0.7.0" [[package]] name = "regex" -version = "2024.7.24" +version = "2024.5.15" description = "Alternative regular expression module, to replace re." optional = false python-versions = ">=3.8" files = [ - {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b0d3f567fafa0633aee87f08b9276c7062da9616931382993c03808bb68ce"}, - {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3426de3b91d1bc73249042742f45c2148803c111d1175b283270177fdf669024"}, - {file = "regex-2024.7.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f273674b445bcb6e4409bf8d1be67bc4b58e8b46fd0d560055d515b8830063cd"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23acc72f0f4e1a9e6e9843d6328177ae3074b4182167e34119ec7233dfeccf53"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65fd3d2e228cae024c411c5ccdffae4c315271eee4a8b839291f84f796b34eca"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c414cbda77dbf13c3bc88b073a1a9f375c7b0cb5e115e15d4b73ec3a2fbc6f59"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7a89eef64b5455835f5ed30254ec19bf41f7541cd94f266ab7cbd463f00c41"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19c65b00d42804e3fbea9708f0937d157e53429a39b7c61253ff15670ff62cb5"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7a5486ca56c8869070a966321d5ab416ff0f83f30e0e2da1ab48815c8d165d46"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f51f9556785e5a203713f5efd9c085b4a45aecd2a42573e2b5041881b588d1f"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a4997716674d36a82eab3e86f8fa77080a5d8d96a389a61ea1d0e3a94a582cf7"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c0abb5e4e8ce71a61d9446040c1e86d4e6d23f9097275c5bd49ed978755ff0fe"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:18300a1d78cf1290fa583cd8b7cde26ecb73e9f5916690cf9d42de569c89b1ce"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:416c0e4f56308f34cdb18c3f59849479dde5b19febdcd6e6fa4d04b6c31c9faa"}, - {file = "regex-2024.7.24-cp310-cp310-win32.whl", hash = "sha256:fb168b5924bef397b5ba13aabd8cf5df7d3d93f10218d7b925e360d436863f66"}, - {file = "regex-2024.7.24-cp310-cp310-win_amd64.whl", hash = "sha256:6b9fc7e9cc983e75e2518496ba1afc524227c163e43d706688a6bb9eca41617e"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:382281306e3adaaa7b8b9ebbb3ffb43358a7bbf585fa93821300a418bb975281"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fdd1384619f406ad9037fe6b6eaa3de2749e2e12084abc80169e8e075377d3b"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3d974d24edb231446f708c455fd08f94c41c1ff4f04bcf06e5f36df5ef50b95a"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec4419a3fe6cf8a4795752596dfe0adb4aea40d3683a132bae9c30b81e8d73"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb563dd3aea54c797adf513eeec819c4213d7dbfc311874eb4fd28d10f2ff0f2"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45104baae8b9f67569f0f1dca5e1f1ed77a54ae1cd8b0b07aba89272710db61e"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:994448ee01864501912abf2bad9203bffc34158e80fe8bfb5b031f4f8e16da51"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fac296f99283ac232d8125be932c5cd7644084a30748fda013028c815ba3364"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7e37e809b9303ec3a179085415cb5f418ecf65ec98cdfe34f6a078b46ef823ee"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:01b689e887f612610c869421241e075c02f2e3d1ae93a037cb14f88ab6a8934c"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f6442f0f0ff81775eaa5b05af8a0ffa1dda36e9cf6ec1e0d3d245e8564b684ce"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:871e3ab2838fbcb4e0865a6e01233975df3a15e6fce93b6f99d75cacbd9862d1"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c918b7a1e26b4ab40409820ddccc5d49871a82329640f5005f73572d5eaa9b5e"}, - {file = "regex-2024.7.24-cp311-cp311-win32.whl", hash = "sha256:2dfbb8baf8ba2c2b9aa2807f44ed272f0913eeeba002478c4577b8d29cde215c"}, - {file = "regex-2024.7.24-cp311-cp311-win_amd64.whl", hash = "sha256:538d30cd96ed7d1416d3956f94d54e426a8daf7c14527f6e0d6d425fcb4cca52"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fe4ebef608553aff8deb845c7f4f1d0740ff76fa672c011cc0bacb2a00fbde86"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:74007a5b25b7a678459f06559504f1eec2f0f17bca218c9d56f6a0a12bfffdad"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7df9ea48641da022c2a3c9c641650cd09f0cd15e8908bf931ad538f5ca7919c9"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1141a1dcc32904c47f6846b040275c6e5de0bf73f17d7a409035d55b76f289"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80c811cfcb5c331237d9bad3bea2c391114588cf4131707e84d9493064d267f9"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7214477bf9bd195894cf24005b1e7b496f46833337b5dedb7b2a6e33f66d962c"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d55588cba7553f0b6ec33130bc3e114b355570b45785cebdc9daed8c637dd440"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558a57cfc32adcf19d3f791f62b5ff564922942e389e3cfdb538a23d65a6b610"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a512eed9dfd4117110b1881ba9a59b31433caed0c4101b361f768e7bcbaf93c5"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:86b17ba823ea76256b1885652e3a141a99a5c4422f4a869189db328321b73799"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5eefee9bfe23f6df09ffb6dfb23809f4d74a78acef004aa904dc7c88b9944b05"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:731fcd76bbdbf225e2eb85b7c38da9633ad3073822f5ab32379381e8c3c12e94"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eaef80eac3b4cfbdd6de53c6e108b4c534c21ae055d1dbea2de6b3b8ff3def38"}, - {file = "regex-2024.7.24-cp312-cp312-win32.whl", hash = "sha256:185e029368d6f89f36e526764cf12bf8d6f0e3a2a7737da625a76f594bdfcbfc"}, - {file = "regex-2024.7.24-cp312-cp312-win_amd64.whl", hash = "sha256:2f1baff13cc2521bea83ab2528e7a80cbe0ebb2c6f0bfad15be7da3aed443908"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:66b4c0731a5c81921e938dcf1a88e978264e26e6ac4ec96a4d21ae0354581ae0"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:88ecc3afd7e776967fa16c80f974cb79399ee8dc6c96423321d6f7d4b881c92b"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64bd50cf16bcc54b274e20235bf8edbb64184a30e1e53873ff8d444e7ac656b2"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb462f0e346fcf41a901a126b50f8781e9a474d3927930f3490f38a6e73b6950"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a82465ebbc9b1c5c50738536fdfa7cab639a261a99b469c9d4c7dcbb2b3f1e57"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68a8f8c046c6466ac61a36b65bb2395c74451df2ffb8458492ef49900efed293"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac8e84fff5d27420f3c1e879ce9929108e873667ec87e0c8eeb413a5311adfe"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba2537ef2163db9e6ccdbeb6f6424282ae4dea43177402152c67ef869cf3978b"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:43affe33137fcd679bdae93fb25924979517e011f9dea99163f80b82eadc7e53"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:c9bb87fdf2ab2370f21e4d5636e5317775e5d51ff32ebff2cf389f71b9b13750"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:945352286a541406f99b2655c973852da7911b3f4264e010218bbc1cc73168f2"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8bc593dcce679206b60a538c302d03c29b18e3d862609317cb560e18b66d10cf"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:3f3b6ca8eae6d6c75a6cff525c8530c60e909a71a15e1b731723233331de4169"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c51edc3541e11fbe83f0c4d9412ef6c79f664a3745fab261457e84465ec9d5a8"}, - {file = "regex-2024.7.24-cp38-cp38-win32.whl", hash = "sha256:d0a07763776188b4db4c9c7fb1b8c494049f84659bb387b71c73bbc07f189e96"}, - {file = "regex-2024.7.24-cp38-cp38-win_amd64.whl", hash = "sha256:8fd5afd101dcf86a270d254364e0e8dddedebe6bd1ab9d5f732f274fa00499a5"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0ffe3f9d430cd37d8fa5632ff6fb36d5b24818c5c986893063b4e5bdb84cdf24"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25419b70ba00a16abc90ee5fce061228206173231f004437730b67ac77323f0d"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33e2614a7ce627f0cdf2ad104797d1f68342d967de3695678c0cb84f530709f8"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33a0021893ede5969876052796165bab6006559ab845fd7b515a30abdd990dc"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04ce29e2c5fedf296b1a1b0acc1724ba93a36fb14031f3abfb7abda2806c1535"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b16582783f44fbca6fcf46f61347340c787d7530d88b4d590a397a47583f31dd"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:836d3cc225b3e8a943d0b02633fb2f28a66e281290302a79df0e1eaa984ff7c1"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:438d9f0f4bc64e8dea78274caa5af971ceff0f8771e1a2333620969936ba10be"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:973335b1624859cb0e52f96062a28aa18f3a5fc77a96e4a3d6d76e29811a0e6e"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c5e69fd3eb0b409432b537fe3c6f44ac089c458ab6b78dcec14478422879ec5f"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fbf8c2f00904eaf63ff37718eb13acf8e178cb940520e47b2f05027f5bb34ce3"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2757ace61bc4061b69af19e4689fa4416e1a04840f33b441034202b5cd02d4"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:44fc61b99035fd9b3b9453f1713234e5a7c92a04f3577252b45feefe1b327759"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:84c312cdf839e8b579f504afcd7b65f35d60b6285d892b19adea16355e8343c9"}, - {file = "regex-2024.7.24-cp39-cp39-win32.whl", hash = "sha256:ca5b2028c2f7af4e13fb9fc29b28d0ce767c38c7facdf64f6c2cd040413055f1"}, - {file = "regex-2024.7.24-cp39-cp39-win_amd64.whl", hash = "sha256:7c479f5ae937ec9985ecaf42e2e10631551d909f203e31308c12d703922742f9"}, - {file = "regex-2024.7.24.tar.gz", hash = "sha256:9cfd009eed1a46b27c14039ad5bbc5e71b6367c5b2e6d5f5da0ea91600817506"}, + {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f"}, + {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6"}, + {file = "regex-2024.5.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0bd000c6e266927cb7a1bc39d55be95c4b4f65c5be53e659537537e019232b1"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eaa7ddaf517aa095fa8da0b5015c44d03da83f5bd49c87961e3c997daed0de7"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba68168daedb2c0bab7fd7e00ced5ba90aebf91024dea3c88ad5063c2a562cca"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e8d717bca3a6e2064fc3a08df5cbe366369f4b052dcd21b7416e6d71620dca1"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1337b7dbef9b2f71121cdbf1e97e40de33ff114801263b275aafd75303bd62b5"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9ebd0a36102fcad2f03696e8af4ae682793a5d30b46c647eaf280d6cfb32796"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9efa1a32ad3a3ea112224897cdaeb6aa00381627f567179c0314f7b65d354c62"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1595f2d10dff3d805e054ebdc41c124753631b6a471b976963c7b28543cf13b0"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b802512f3e1f480f41ab5f2cfc0e2f761f08a1f41092d6718868082fc0d27143"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a0981022dccabca811e8171f913de05720590c915b033b7e601f35ce4ea7019f"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:19068a6a79cf99a19ccefa44610491e9ca02c2be3305c7760d3831d38a467a6f"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b5269484f6126eee5e687785e83c6b60aad7663dafe842b34691157e5083e53"}, + {file = "regex-2024.5.15-cp310-cp310-win32.whl", hash = "sha256:ada150c5adfa8fbcbf321c30c751dc67d2f12f15bd183ffe4ec7cde351d945b3"}, + {file = "regex-2024.5.15-cp310-cp310-win_amd64.whl", hash = "sha256:ac394ff680fc46b97487941f5e6ae49a9f30ea41c6c6804832063f14b2a5a145"}, + {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f5b1dff3ad008dccf18e652283f5e5339d70bf8ba7c98bf848ac33db10f7bc7a"}, + {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6a2b494a76983df8e3d3feea9b9ffdd558b247e60b92f877f93a1ff43d26656"}, + {file = "regex-2024.5.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a32b96f15c8ab2e7d27655969a23895eb799de3665fa94349f3b2fbfd547236f"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10002e86e6068d9e1c91eae8295ef690f02f913c57db120b58fdd35a6bb1af35"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec54d5afa89c19c6dd8541a133be51ee1017a38b412b1321ccb8d6ddbeb4cf7d"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10e4ce0dca9ae7a66e6089bb29355d4432caed736acae36fef0fdd7879f0b0cb"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e507ff1e74373c4d3038195fdd2af30d297b4f0950eeda6f515ae3d84a1770f"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1f059a4d795e646e1c37665b9d06062c62d0e8cc3c511fe01315973a6542e40"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0721931ad5fe0dda45d07f9820b90b2148ccdd8e45bb9e9b42a146cb4f695649"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:833616ddc75ad595dee848ad984d067f2f31be645d603e4d158bba656bbf516c"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:287eb7f54fc81546346207c533ad3c2c51a8d61075127d7f6d79aaf96cdee890"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:19dfb1c504781a136a80ecd1fff9f16dddf5bb43cec6871778c8a907a085bb3d"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:119af6e56dce35e8dfb5222573b50c89e5508d94d55713c75126b753f834de68"}, + {file = "regex-2024.5.15-cp311-cp311-win32.whl", hash = "sha256:1c1c174d6ec38d6c8a7504087358ce9213d4332f6293a94fbf5249992ba54efa"}, + {file = "regex-2024.5.15-cp311-cp311-win_amd64.whl", hash = "sha256:9e717956dcfd656f5055cc70996ee2cc82ac5149517fc8e1b60261b907740201"}, + {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:632b01153e5248c134007209b5c6348a544ce96c46005d8456de1d552455b014"}, + {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e64198f6b856d48192bf921421fdd8ad8eb35e179086e99e99f711957ffedd6e"}, + {file = "regex-2024.5.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68811ab14087b2f6e0fc0c2bae9ad689ea3584cad6917fc57be6a48bbd012c49"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ec0c2fea1e886a19c3bee0cd19d862b3aa75dcdfb42ebe8ed30708df64687a"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0c0c0003c10f54a591d220997dd27d953cd9ccc1a7294b40a4be5312be8797b"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2431b9e263af1953c55abbd3e2efca67ca80a3de8a0437cb58e2421f8184717a"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a605586358893b483976cffc1723fb0f83e526e8f14c6e6614e75919d9862cf"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391d7f7f1e409d192dba8bcd42d3e4cf9e598f3979cdaed6ab11288da88cb9f2"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9ff11639a8d98969c863d4617595eb5425fd12f7c5ef6621a4b74b71ed8726d5"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4eee78a04e6c67e8391edd4dad3279828dd66ac4b79570ec998e2155d2e59fd5"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8fe45aa3f4aa57faabbc9cb46a93363edd6197cbc43523daea044e9ff2fea83e"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:d0a3d8d6acf0c78a1fff0e210d224b821081330b8524e3e2bc5a68ef6ab5803d"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c486b4106066d502495b3025a0a7251bf37ea9540433940a23419461ab9f2a80"}, + {file = "regex-2024.5.15-cp312-cp312-win32.whl", hash = "sha256:c49e15eac7c149f3670b3e27f1f28a2c1ddeccd3a2812cba953e01be2ab9b5fe"}, + {file = "regex-2024.5.15-cp312-cp312-win_amd64.whl", hash = "sha256:673b5a6da4557b975c6c90198588181029c60793835ce02f497ea817ff647cb2"}, + {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:87e2a9c29e672fc65523fb47a90d429b70ef72b901b4e4b1bd42387caf0d6835"}, + {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c3bea0ba8b73b71b37ac833a7f3fd53825924165da6a924aec78c13032f20850"}, + {file = "regex-2024.5.15-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bfc4f82cabe54f1e7f206fd3d30fda143f84a63fe7d64a81558d6e5f2e5aaba9"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5bb9425fe881d578aeca0b2b4b3d314ec88738706f66f219c194d67179337cb"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64c65783e96e563103d641760664125e91bd85d8e49566ee560ded4da0d3e704"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf2430df4148b08fb4324b848672514b1385ae3807651f3567871f130a728cc3"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5397de3219a8b08ae9540c48f602996aa6b0b65d5a61683e233af8605c42b0f2"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:455705d34b4154a80ead722f4f185b04c4237e8e8e33f265cd0798d0e44825fa"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2b6f1b3bb6f640c1a92be3bbfbcb18657b125b99ecf141fb3310b5282c7d4ed"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3ad070b823ca5890cab606c940522d05d3d22395d432f4aaaf9d5b1653e47ced"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5b5467acbfc153847d5adb21e21e29847bcb5870e65c94c9206d20eb4e99a384"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e6662686aeb633ad65be2a42b4cb00178b3fbf7b91878f9446075c404ada552f"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:2b4c884767504c0e2401babe8b5b7aea9148680d2e157fa28f01529d1f7fcf67"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:3cd7874d57f13bf70078f1ff02b8b0aa48d5b9ed25fc48547516c6aba36f5741"}, + {file = "regex-2024.5.15-cp38-cp38-win32.whl", hash = "sha256:e4682f5ba31f475d58884045c1a97a860a007d44938c4c0895f41d64481edbc9"}, + {file = "regex-2024.5.15-cp38-cp38-win_amd64.whl", hash = "sha256:d99ceffa25ac45d150e30bd9ed14ec6039f2aad0ffa6bb87a5936f5782fc1569"}, + {file = "regex-2024.5.15-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13cdaf31bed30a1e1c2453ef6015aa0983e1366fad2667657dbcac7b02f67133"}, + {file = "regex-2024.5.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cac27dcaa821ca271855a32188aa61d12decb6fe45ffe3e722401fe61e323cd1"}, + {file = "regex-2024.5.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7dbe2467273b875ea2de38ded4eba86cbcbc9a1a6d0aa11dcf7bd2e67859c435"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64f18a9a3513a99c4bef0e3efd4c4a5b11228b48aa80743be822b71e132ae4f5"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d347a741ea871c2e278fde6c48f85136c96b8659b632fb57a7d1ce1872547600"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1878b8301ed011704aea4c806a3cadbd76f84dece1ec09cc9e4dc934cfa5d4da"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4babf07ad476aaf7830d77000874d7611704a7fcf68c9c2ad151f5d94ae4bfc4"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35cb514e137cb3488bce23352af3e12fb0dbedd1ee6e60da053c69fb1b29cc6c"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cdd09d47c0b2efee9378679f8510ee6955d329424c659ab3c5e3a6edea696294"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:72d7a99cd6b8f958e85fc6ca5b37c4303294954eac1376535b03c2a43eb72629"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a094801d379ab20c2135529948cb84d417a2169b9bdceda2a36f5f10977ebc16"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c0c18345010870e58238790a6779a1219b4d97bd2e77e1140e8ee5d14df071aa"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:16093f563098448ff6b1fa68170e4acbef94e6b6a4e25e10eae8598bb1694b5d"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e38a7d4e8f633a33b4c7350fbd8bad3b70bf81439ac67ac38916c4a86b465456"}, + {file = "regex-2024.5.15-cp39-cp39-win32.whl", hash = "sha256:71a455a3c584a88f654b64feccc1e25876066c4f5ef26cd6dd711308aa538694"}, + {file = "regex-2024.5.15-cp39-cp39-win_amd64.whl", hash = "sha256:cab12877a9bdafde5500206d1020a584355a97884dfd388af3699e9137bf7388"}, + {file = "regex-2024.5.15.tar.gz", hash = "sha256:d3ee02d9e5f482cc8309134a91eeaacbdd2261ba111b0fef3748eeb4913e6a2c"}, ] [[package]] name = "requests" -version = "2.32.3" +version = "2.32.2" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, - {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, + {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, + {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, ] [package.dependencies] @@ -2577,233 +2427,219 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rpds-py" -version = "0.20.0" +version = "0.18.1" description = "Python bindings to Rust's persistent data structures (rpds)" optional = true python-versions = ">=3.8" files = [ - {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, - {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, - {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, - {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, - {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, - {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, - {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, - {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, - {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, - {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, - {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, - {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, - {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, - {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, - {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, - {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, - {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, - {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, - {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, - {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, - {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, - {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, - {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, - {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, - {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, - {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, - {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, - {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, - {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, - {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, - {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, - {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, - {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, - {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, - {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, - {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, - {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, - {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, - {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, - {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, - {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, - {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, - {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, + {file = "rpds_py-0.18.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d31dea506d718693b6b2cffc0648a8929bdc51c70a311b2770f09611caa10d53"}, + {file = "rpds_py-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:732672fbc449bab754e0b15356c077cc31566df874964d4801ab14f71951ea80"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a98a1f0552b5f227a3d6422dbd61bc6f30db170939bd87ed14f3c339aa6c7c9"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7f1944ce16401aad1e3f7d312247b3d5de7981f634dc9dfe90da72b87d37887d"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38e14fb4e370885c4ecd734f093a2225ee52dc384b86fa55fe3f74638b2cfb09"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08d74b184f9ab6289b87b19fe6a6d1a97fbfea84b8a3e745e87a5de3029bf944"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d70129cef4a8d979caa37e7fe957202e7eee8ea02c5e16455bc9808a59c6b2f0"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0bb20e3a11bd04461324a6a798af34d503f8d6f1aa3d2aa8901ceaf039176d"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81c5196a790032e0fc2464c0b4ab95f8610f96f1f2fa3d4deacce6a79852da60"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f3027be483868c99b4985fda802a57a67fdf30c5d9a50338d9db646d590198da"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d44607f98caa2961bab4fa3c4309724b185b464cdc3ba6f3d7340bac3ec97cc1"}, + {file = "rpds_py-0.18.1-cp310-none-win32.whl", hash = "sha256:c273e795e7a0f1fddd46e1e3cb8be15634c29ae8ff31c196debb620e1edb9333"}, + {file = "rpds_py-0.18.1-cp310-none-win_amd64.whl", hash = "sha256:8352f48d511de5f973e4f2f9412736d7dea76c69faa6d36bcf885b50c758ab9a"}, + {file = "rpds_py-0.18.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b5ff7e1d63a8281654b5e2896d7f08799378e594f09cf3674e832ecaf396ce8"}, + {file = "rpds_py-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8927638a4d4137a289e41d0fd631551e89fa346d6dbcfc31ad627557d03ceb6d"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:154bf5c93d79558b44e5b50cc354aa0459e518e83677791e6adb0b039b7aa6a7"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07f2139741e5deb2c5154a7b9629bc5aa48c766b643c1a6750d16f865a82c5fc"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c7672e9fba7425f79019db9945b16e308ed8bc89348c23d955c8c0540da0a07"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:489bdfe1abd0406eba6b3bb4fdc87c7fa40f1031de073d0cfb744634cc8fa261"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c20f05e8e3d4fc76875fc9cb8cf24b90a63f5a1b4c5b9273f0e8225e169b100"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:967342e045564cef76dfcf1edb700b1e20838d83b1aa02ab313e6a497cf923b8"}, + {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2cc7c1a47f3a63282ab0f422d90ddac4aa3034e39fc66a559ab93041e6505da7"}, + {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f7afbfee1157e0f9376c00bb232e80a60e59ed716e3211a80cb8506550671e6e"}, + {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e6934d70dc50f9f8ea47081ceafdec09245fd9f6032669c3b45705dea096b88"}, + {file = "rpds_py-0.18.1-cp311-none-win32.whl", hash = "sha256:c69882964516dc143083d3795cb508e806b09fc3800fd0d4cddc1df6c36e76bb"}, + {file = "rpds_py-0.18.1-cp311-none-win_amd64.whl", hash = "sha256:70a838f7754483bcdc830444952fd89645569e7452e3226de4a613a4c1793fb2"}, + {file = "rpds_py-0.18.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3dd3cd86e1db5aadd334e011eba4e29d37a104b403e8ca24dcd6703c68ca55b3"}, + {file = "rpds_py-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:05f3d615099bd9b13ecf2fc9cf2d839ad3f20239c678f461c753e93755d629ee"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b2b771b13eee8729a5049c976197ff58a27a3829c018a04341bcf1ae409b2b"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee17cd26b97d537af8f33635ef38be873073d516fd425e80559f4585a7b90c43"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b646bf655b135ccf4522ed43d6902af37d3f5dbcf0da66c769a2b3938b9d8184"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19ba472b9606c36716062c023afa2484d1e4220548751bda14f725a7de17b4f6"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e30ac5e329098903262dc5bdd7e2086e0256aa762cc8b744f9e7bf2a427d3f8"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d58ad6317d188c43750cb76e9deacf6051d0f884d87dc6518e0280438648a9ac"}, + {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e1735502458621921cee039c47318cb90b51d532c2766593be6207eec53e5c4c"}, + {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f5bab211605d91db0e2995a17b5c6ee5edec1270e46223e513eaa20da20076ac"}, + {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2fc24a329a717f9e2448f8cd1f960f9dac4e45b6224d60734edeb67499bab03a"}, + {file = "rpds_py-0.18.1-cp312-none-win32.whl", hash = "sha256:1805d5901779662d599d0e2e4159d8a82c0b05faa86ef9222bf974572286b2b6"}, + {file = "rpds_py-0.18.1-cp312-none-win_amd64.whl", hash = "sha256:720edcb916df872d80f80a1cc5ea9058300b97721efda8651efcd938a9c70a72"}, + {file = "rpds_py-0.18.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:c827576e2fa017a081346dce87d532a5310241648eb3700af9a571a6e9fc7e74"}, + {file = "rpds_py-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:aa3679e751408d75a0b4d8d26d6647b6d9326f5e35c00a7ccd82b78ef64f65f8"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0abeee75434e2ee2d142d650d1e54ac1f8b01e6e6abdde8ffd6eeac6e9c38e20"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed402d6153c5d519a0faf1bb69898e97fb31613b49da27a84a13935ea9164dfc"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:338dee44b0cef8b70fd2ef54b4e09bb1b97fc6c3a58fea5db6cc083fd9fc2724"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7750569d9526199c5b97e5a9f8d96a13300950d910cf04a861d96f4273d5b104"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:607345bd5912aacc0c5a63d45a1f73fef29e697884f7e861094e443187c02be5"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:207c82978115baa1fd8d706d720b4a4d2b0913df1c78c85ba73fe6c5804505f0"}, + {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6d1e42d2735d437e7e80bab4d78eb2e459af48c0a46e686ea35f690b93db792d"}, + {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5463c47c08630007dc0fe99fb480ea4f34a89712410592380425a9b4e1611d8e"}, + {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:06d218939e1bf2ca50e6b0ec700ffe755e5216a8230ab3e87c059ebb4ea06afc"}, + {file = "rpds_py-0.18.1-cp38-none-win32.whl", hash = "sha256:312fe69b4fe1ffbe76520a7676b1e5ac06ddf7826d764cc10265c3b53f96dbe9"}, + {file = "rpds_py-0.18.1-cp38-none-win_amd64.whl", hash = "sha256:9437ca26784120a279f3137ee080b0e717012c42921eb07861b412340f85bae2"}, + {file = "rpds_py-0.18.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:19e515b78c3fc1039dd7da0a33c28c3154458f947f4dc198d3c72db2b6b5dc93"}, + {file = "rpds_py-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7b28c5b066bca9a4eb4e2f2663012debe680f097979d880657f00e1c30875a0"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:673fdbbf668dd958eff750e500495ef3f611e2ecc209464f661bc82e9838991e"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d960de62227635d2e61068f42a6cb6aae91a7fe00fca0e3aeed17667c8a34611"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352a88dc7892f1da66b6027af06a2e7e5d53fe05924cc2cfc56495b586a10b72"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e0ee01ad8260184db21468a6e1c37afa0529acc12c3a697ee498d3c2c4dcaf3"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4c39ad2f512b4041343ea3c7894339e4ca7839ac38ca83d68a832fc8b3748ab"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aaa71ee43a703c321906813bb252f69524f02aa05bf4eec85f0c41d5d62d0f4c"}, + {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6cd8098517c64a85e790657e7b1e509b9fe07487fd358e19431cb120f7d96338"}, + {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4adec039b8e2928983f885c53b7cc4cda8965b62b6596501a0308d2703f8af1b"}, + {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32b7daaa3e9389db3695964ce8e566e3413b0c43e3394c05e4b243a4cd7bef26"}, + {file = "rpds_py-0.18.1-cp39-none-win32.whl", hash = "sha256:2625f03b105328729f9450c8badda34d5243231eef6535f80064d57035738360"}, + {file = "rpds_py-0.18.1-cp39-none-win_amd64.whl", hash = "sha256:bf18932d0003c8c4d51a39f244231986ab23ee057d235a12b2684ea26a353590"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cbfbea39ba64f5e53ae2915de36f130588bba71245b418060ec3330ebf85678e"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a3d456ff2a6a4d2adcdf3c1c960a36f4fd2fec6e3b4902a42a384d17cf4e7a65"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7700936ef9d006b7ef605dc53aa364da2de5a3aa65516a1f3ce73bf82ecfc7ae"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51584acc5916212e1bf45edd17f3a6b05fe0cbb40482d25e619f824dccb679de"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:942695a206a58d2575033ff1e42b12b2aece98d6003c6bc739fbf33d1773b12f"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b906b5f58892813e5ba5c6056d6a5ad08f358ba49f046d910ad992196ea61397"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6f8e3fecca256fefc91bb6765a693d96692459d7d4c644660a9fff32e517843"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7732770412bab81c5a9f6d20aeb60ae943a9b36dcd990d876a773526468e7163"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bd1105b50ede37461c1d51b9698c4f4be6e13e69a908ab7751e3807985fc0346"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:618916f5535784960f3ecf8111581f4ad31d347c3de66d02e728de460a46303c"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:17c6d2155e2423f7e79e3bb18151c686d40db42d8645e7977442170c360194d4"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c4c4c3f878df21faf5fac86eda32671c27889e13570645a9eea0a1abdd50922"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:fab6ce90574645a0d6c58890e9bcaac8d94dff54fb51c69e5522a7358b80ab64"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:531796fb842b53f2695e94dc338929e9f9dbf473b64710c28af5a160b2a8927d"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:740884bc62a5e2bbb31e584f5d23b32320fd75d79f916f15a788d527a5e83644"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:998125738de0158f088aef3cb264a34251908dd2e5d9966774fdab7402edfab7"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2be6e9dd4111d5b31ba3b74d17da54a8319d8168890fbaea4b9e5c3de630ae5"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0cee71bc618cd93716f3c1bf56653740d2d13ddbd47673efa8bf41435a60daa"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2c3caec4ec5cd1d18e5dd6ae5194d24ed12785212a90b37f5f7f06b8bedd7139"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:27bba383e8c5231cd559affe169ca0b96ec78d39909ffd817f28b166d7ddd4d8"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:a888e8bdb45916234b99da2d859566f1e8a1d2275a801bb8e4a9644e3c7e7909"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6031b25fb1b06327b43d841f33842b383beba399884f8228a6bb3df3088485ff"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48c2faaa8adfacefcbfdb5f2e2e7bdad081e5ace8d182e5f4ade971f128e6bb3"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:d85164315bd68c0806768dc6bb0429c6f95c354f87485ee3593c4f6b14def2bd"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6afd80f6c79893cfc0574956f78a0add8c76e3696f2d6a15bca2c66c415cf2d4"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa242ac1ff583e4ec7771141606aafc92b361cd90a05c30d93e343a0c2d82a89"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21be4770ff4e08698e1e8e0bce06edb6ea0626e7c8f560bc08222880aca6a6f"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c45a639e93a0c5d4b788b2613bd637468edd62f8f95ebc6fcc303d58ab3f0a8"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910e71711d1055b2768181efa0a17537b2622afeb0424116619817007f8a2b10"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9bb1f182a97880f6078283b3505a707057c42bf55d8fca604f70dedfdc0772a"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d54f74f40b1f7aaa595a02ff42ef38ca654b1469bef7d52867da474243cc633"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:8d2e182c9ee01135e11e9676e9a62dfad791a7a467738f06726872374a83db49"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:636a15acc588f70fda1661234761f9ed9ad79ebed3f2125d44be0862708b666e"}, + {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, ] [[package]] name = "safetensors" -version = "0.4.4" +version = "0.4.3" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.4.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2adb497ada13097f30e386e88c959c0fda855a5f6f98845710f5bb2c57e14f12"}, - {file = "safetensors-0.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7db7fdc2d71fd1444d85ca3f3d682ba2df7d61a637dfc6d80793f439eae264ab"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d4f0eed76b430f009fbefca1a0028ddb112891b03cb556d7440d5cd68eb89a9"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:57d216fab0b5c432aabf7170883d7c11671622bde8bd1436c46d633163a703f6"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d9b76322e49c056bcc819f8bdca37a2daa5a6d42c07f30927b501088db03309"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32f0d1f6243e90ee43bc6ee3e8c30ac5b09ca63f5dd35dbc985a1fc5208c451a"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d464bdc384874601a177375028012a5f177f1505279f9456fea84bbc575c7f"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63144e36209ad8e4e65384dbf2d52dd5b1866986079c00a72335402a38aacdc5"}, - {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:051d5ecd490af7245258000304b812825974d5e56f14a3ff7e1b8b2ba6dc2ed4"}, - {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:51bc8429d9376224cd3cf7e8ce4f208b4c930cd10e515b6ac6a72cbc3370f0d9"}, - {file = "safetensors-0.4.4-cp310-none-win32.whl", hash = "sha256:fb7b54830cee8cf9923d969e2df87ce20e625b1af2fd194222ab902d3adcc29c"}, - {file = "safetensors-0.4.4-cp310-none-win_amd64.whl", hash = "sha256:4b3e8aa8226d6560de8c2b9d5ff8555ea482599c670610758afdc97f3e021e9c"}, - {file = "safetensors-0.4.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bbaa31f2cb49013818bde319232ccd72da62ee40f7d2aa532083eda5664e85ff"}, - {file = "safetensors-0.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9fdcb80f4e9fbb33b58e9bf95e7dbbedff505d1bcd1c05f7c7ce883632710006"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55c14c20be247b8a1aeaf3ab4476265e3ca83096bb8e09bb1a7aa806088def4f"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:949aaa1118660f992dbf0968487b3e3cfdad67f948658ab08c6b5762e90cc8b6"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c11a4ab7debc456326a2bac67f35ee0ac792bcf812c7562a4a28559a5c795e27"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0cea44bba5c5601b297bc8307e4075535b95163402e4906b2e9b82788a2a6df"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9d752c97f6bbe327352f76e5b86442d776abc789249fc5e72eacb49e6916482"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03f2bb92e61b055ef6cc22883ad1ae898010a95730fa988c60a23800eb742c2c"}, - {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf3f91a9328a941acc44eceffd4e1f5f89b030985b2966637e582157173b98"}, - {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:20d218ec2b6899d29d6895419a58b6e44cc5ff8f0cc29fac8d236a8978ab702e"}, - {file = "safetensors-0.4.4-cp311-none-win32.whl", hash = "sha256:8079486118919f600c603536e2490ca37b3dbd3280e3ad6eaacfe6264605ac8a"}, - {file = "safetensors-0.4.4-cp311-none-win_amd64.whl", hash = "sha256:2f8c2eb0615e2e64ee27d478c7c13f51e5329d7972d9e15528d3e4cfc4a08f0d"}, - {file = "safetensors-0.4.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baec5675944b4a47749c93c01c73d826ef7d42d36ba8d0dba36336fa80c76426"}, - {file = "safetensors-0.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f15117b96866401825f3e94543145028a2947d19974429246ce59403f49e77c6"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a13a9caea485df164c51be4eb0c87f97f790b7c3213d635eba2314d959fe929"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b54bc4ca5f9b9bba8cd4fb91c24b2446a86b5ae7f8975cf3b7a277353c3127c"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08332c22e03b651c8eb7bf5fc2de90044f3672f43403b3d9ac7e7e0f4f76495e"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bb62841e839ee992c37bb75e75891c7f4904e772db3691c59daaca5b4ab960e1"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e5b927acc5f2f59547270b0309a46d983edc44be64e1ca27a7fcb0474d6cd67"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a69c71b1ae98a8021a09a0b43363b0143b0ce74e7c0e83cacba691b62655fb8"}, - {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23654ad162c02a5636f0cd520a0310902c4421aab1d91a0b667722a4937cc445"}, - {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0677c109d949cf53756859160b955b2e75b0eefe952189c184d7be30ecf7e858"}, - {file = "safetensors-0.4.4-cp312-none-win32.whl", hash = "sha256:a51d0ddd4deb8871c6de15a772ef40b3dbd26a3c0451bb9e66bc76fc5a784e5b"}, - {file = "safetensors-0.4.4-cp312-none-win_amd64.whl", hash = "sha256:2d065059e75a798bc1933c293b68d04d79b586bb7f8c921e0ca1e82759d0dbb1"}, - {file = "safetensors-0.4.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9d625692578dd40a112df30c02a1adf068027566abd8e6a74893bb13d441c150"}, - {file = "safetensors-0.4.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7cabcf39c81e5b988d0adefdaea2eb9b4fd9bd62d5ed6559988c62f36bfa9a89"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8359bef65f49d51476e9811d59c015f0ddae618ee0e44144f5595278c9f8268c"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a32c662e7df9226fd850f054a3ead0e4213a96a70b5ce37b2d26ba27004e013"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c329a4dcc395364a1c0d2d1574d725fe81a840783dda64c31c5a60fc7d41472c"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:239ee093b1db877c9f8fe2d71331a97f3b9c7c0d3ab9f09c4851004a11f44b65"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd574145d930cf9405a64f9923600879a5ce51d9f315443a5f706374841327b6"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f6784eed29f9e036acb0b7769d9e78a0dc2c72c2d8ba7903005350d817e287a4"}, - {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:65a4a6072436bf0a4825b1c295d248cc17e5f4651e60ee62427a5bcaa8622a7a"}, - {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:df81e3407630de060ae8313da49509c3caa33b1a9415562284eaf3d0c7705f9f"}, - {file = "safetensors-0.4.4-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:e4a0f374200e8443d9746e947ebb346c40f83a3970e75a685ade0adbba5c48d9"}, - {file = "safetensors-0.4.4-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:181fb5f3dee78dae7fd7ec57d02e58f7936498d587c6b7c1c8049ef448c8d285"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb4ac1d8f6b65ec84ddfacd275079e89d9df7c92f95675ba96c4f790a64df6e"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:76897944cd9239e8a70955679b531b9a0619f76e25476e57ed373322d9c2075d"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a9e9d1a27e51a0f69e761a3d581c3af46729ec1c988fa1f839e04743026ae35"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:005ef9fc0f47cb9821c40793eb029f712e97278dae84de91cb2b4809b856685d"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26987dac3752688c696c77c3576f951dbbdb8c57f0957a41fb6f933cf84c0b62"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c05270b290acd8d249739f40d272a64dd597d5a4b90f27d830e538bc2549303c"}, - {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:068d3a33711fc4d93659c825a04480ff5a3854e1d78632cdc8f37fee917e8a60"}, - {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:063421ef08ca1021feea8b46951251b90ae91f899234dd78297cbe7c1db73b99"}, - {file = "safetensors-0.4.4-cp37-none-win32.whl", hash = "sha256:d52f5d0615ea83fd853d4e1d8acf93cc2e0223ad4568ba1e1f6ca72e94ea7b9d"}, - {file = "safetensors-0.4.4-cp37-none-win_amd64.whl", hash = "sha256:88a5ac3280232d4ed8e994cbc03b46a1807ce0aa123867b40c4a41f226c61f94"}, - {file = "safetensors-0.4.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:3467ab511bfe3360967d7dc53b49f272d59309e57a067dd2405b4d35e7dcf9dc"}, - {file = "safetensors-0.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2ab4c96d922e53670ce25fbb9b63d5ea972e244de4fa1dd97b590d9fd66aacef"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87df18fce4440477c3ef1fd7ae17c704a69a74a77e705a12be135ee0651a0c2d"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e5fe345b2bc7d88587149ac11def1f629d2671c4c34f5df38aed0ba59dc37f8"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9f1a3e01dce3cd54060791e7e24588417c98b941baa5974700eeb0b8eb65b0a0"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c6bf35e9a8998d8339fd9a05ac4ce465a4d2a2956cc0d837b67c4642ed9e947"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:166c0c52f6488b8538b2a9f3fbc6aad61a7261e170698779b371e81b45f0440d"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87e9903b8668a16ef02c08ba4ebc91e57a49c481e9b5866e31d798632805014b"}, - {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a9c421153aa23c323bd8483d4155b4eee82c9a50ac11cccd83539104a8279c64"}, - {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a4b8617499b2371c7353302c5116a7e0a3a12da66389ce53140e607d3bf7b3d3"}, - {file = "safetensors-0.4.4-cp38-none-win32.whl", hash = "sha256:c6280f5aeafa1731f0a3709463ab33d8e0624321593951aefada5472f0b313fd"}, - {file = "safetensors-0.4.4-cp38-none-win_amd64.whl", hash = "sha256:6ceed6247fc2d33b2a7b7d25d8a0fe645b68798856e0bc7a9800c5fd945eb80f"}, - {file = "safetensors-0.4.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5cf6c6f6193797372adf50c91d0171743d16299491c75acad8650107dffa9269"}, - {file = "safetensors-0.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:419010156b914a3e5da4e4adf992bee050924d0fe423c4b329e523e2c14c3547"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88f6fd5a5c1302ce79993cc5feeadcc795a70f953c762544d01fb02b2db4ea33"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d468cffb82d90789696d5b4d8b6ab8843052cba58a15296691a7a3df55143cd2"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9353c2af2dd467333d4850a16edb66855e795561cd170685178f706c80d2c71e"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:83c155b4a33368d9b9c2543e78f2452090fb030c52401ca608ef16fa58c98353"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9850754c434e636ce3dc586f534bb23bcbd78940c304775bee9005bf610e98f1"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:275f500b4d26f67b6ec05629a4600645231bd75e4ed42087a7c1801bff04f4b3"}, - {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5c2308de665b7130cd0e40a2329278226e4cf083f7400c51ca7e19ccfb3886f3"}, - {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e06a9ebc8656e030ccfe44634f2a541b4b1801cd52e390a53ad8bacbd65f8518"}, - {file = "safetensors-0.4.4-cp39-none-win32.whl", hash = "sha256:ef73df487b7c14b477016947c92708c2d929e1dee2bacdd6fff5a82ed4539537"}, - {file = "safetensors-0.4.4-cp39-none-win_amd64.whl", hash = "sha256:83d054818a8d1198d8bd8bc3ea2aac112a2c19def2bf73758321976788706398"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1d1f34c71371f0e034004a0b583284b45d233dd0b5f64a9125e16b8a01d15067"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a8043a33d58bc9b30dfac90f75712134ca34733ec3d8267b1bd682afe7194f5"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8db8f0c59c84792c12661f8efa85de160f80efe16b87a9d5de91b93f9e0bce3c"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfc1fc38e37630dd12d519bdec9dcd4b345aec9930bb9ce0ed04461f49e58b52"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5c9d86d9b13b18aafa88303e2cd21e677f5da2a14c828d2c460fe513af2e9a5"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:43251d7f29a59120a26f5a0d9583b9e112999e500afabcfdcb91606d3c5c89e3"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2c42e9b277513b81cf507e6121c7b432b3235f980cac04f39f435b7902857f91"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3daacc9a4e3f428a84dd56bf31f20b768eb0b204af891ed68e1f06db9edf546f"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:218bbb9b883596715fc9997bb42470bf9f21bb832c3b34c2bf744d6fa8f2bbba"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bd5efc26b39f7fc82d4ab1d86a7f0644c8e34f3699c33f85bfa9a717a030e1b"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56ad9776b65d8743f86698a1973292c966cf3abff627efc44ed60e66cc538ddd"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:30f23e6253c5f43a809dea02dc28a9f5fa747735dc819f10c073fe1b605e97d4"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:5512078d00263de6cb04e9d26c9ae17611098f52357fea856213e38dc462f81f"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b96c3d9266439d17f35fc2173111d93afc1162f168e95aed122c1ca517b1f8f1"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:08d464aa72a9a13826946b4fb9094bb4b16554bbea2e069e20bd903289b6ced9"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:210160816d5a36cf41f48f38473b6f70d7bcb4b0527bedf0889cc0b4c3bb07db"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb276a53717f2bcfb6df0bcf284d8a12069002508d4c1ca715799226024ccd45"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a2c28c6487f17d8db0089e8b2cdc13de859366b94cc6cdc50e1b0a4147b56551"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7915f0c60e4e6e65d90f136d85dd3b429ae9191c36b380e626064694563dbd9f"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:00eea99ae422fbfa0b46065acbc58b46bfafadfcec179d4b4a32d5c45006af6c"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bb1ed4fcb0b3c2f3ea2c5767434622fe5d660e5752f21ac2e8d737b1e5e480bb"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:73fc9a0a4343188bdb421783e600bfaf81d0793cd4cce6bafb3c2ed567a74cd5"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c37e6b714200824c73ca6eaf007382de76f39466a46e97558b8dc4cf643cfbf"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75698c5c5c542417ac4956acfc420f7d4a2396adca63a015fd66641ea751759"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca1a209157f242eb183e209040097118472e169f2e069bfbd40c303e24866543"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:177f2b60a058f92a3cec7a1786c9106c29eca8987ecdfb79ee88126e5f47fa31"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ee9622e84fe6e4cd4f020e5fda70d6206feff3157731df7151d457fdae18e541"}, - {file = "safetensors-0.4.4.tar.gz", hash = "sha256:5fe3e9b705250d0172ed4e100a811543108653fb2b66b9e702a088ad03772a07"}, + {file = "safetensors-0.4.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:dcf5705cab159ce0130cd56057f5f3425023c407e170bca60b4868048bae64fd"}, + {file = "safetensors-0.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bb4f8c5d0358a31e9a08daeebb68f5e161cdd4018855426d3f0c23bb51087055"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70a5319ef409e7f88686a46607cbc3c428271069d8b770076feaf913664a07ac"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb9c65bd82f9ef3ce4970dc19ee86be5f6f93d032159acf35e663c6bea02b237"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:edb5698a7bc282089f64c96c477846950358a46ede85a1c040e0230344fdde10"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:efcc860be094b8d19ac61b452ec635c7acb9afa77beb218b1d7784c6d41fe8ad"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d88b33980222085dd6001ae2cad87c6068e0991d4f5ccf44975d216db3b57376"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5fc6775529fb9f0ce2266edd3e5d3f10aab068e49f765e11f6f2a63b5367021d"}, + {file = "safetensors-0.4.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9c6ad011c1b4e3acff058d6b090f1da8e55a332fbf84695cf3100c649cc452d1"}, + {file = "safetensors-0.4.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c496c5401c1b9c46d41a7688e8ff5b0310a3b9bae31ce0f0ae870e1ea2b8caf"}, + {file = "safetensors-0.4.3-cp310-none-win32.whl", hash = "sha256:38e2a8666178224a51cca61d3cb4c88704f696eac8f72a49a598a93bbd8a4af9"}, + {file = "safetensors-0.4.3-cp310-none-win_amd64.whl", hash = "sha256:393e6e391467d1b2b829c77e47d726f3b9b93630e6a045b1d1fca67dc78bf632"}, + {file = "safetensors-0.4.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:22f3b5d65e440cec0de8edaa672efa888030802e11c09b3d6203bff60ebff05a"}, + {file = "safetensors-0.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c4fa560ebd4522adddb71dcd25d09bf211b5634003f015a4b815b7647d62ebe"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9afd5358719f1b2cf425fad638fc3c887997d6782da317096877e5b15b2ce93"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d8c5093206ef4b198600ae484230402af6713dab1bd5b8e231905d754022bec7"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0b2104df1579d6ba9052c0ae0e3137c9698b2d85b0645507e6fd1813b70931a"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8cf18888606dad030455d18f6c381720e57fc6a4170ee1966adb7ebc98d4d6a3"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bf4f9d6323d9f86eef5567eabd88f070691cf031d4c0df27a40d3b4aaee755b"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:585c9ae13a205807b63bef8a37994f30c917ff800ab8a1ca9c9b5d73024f97ee"}, + {file = "safetensors-0.4.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faefeb3b81bdfb4e5a55b9bbdf3d8d8753f65506e1d67d03f5c851a6c87150e9"}, + {file = "safetensors-0.4.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:befdf0167ad626f22f6aac6163477fcefa342224a22f11fdd05abb3995c1783c"}, + {file = "safetensors-0.4.3-cp311-none-win32.whl", hash = "sha256:a7cef55929dcbef24af3eb40bedec35d82c3c2fa46338bb13ecf3c5720af8a61"}, + {file = "safetensors-0.4.3-cp311-none-win_amd64.whl", hash = "sha256:840b7ac0eff5633e1d053cc9db12fdf56b566e9403b4950b2dc85393d9b88d67"}, + {file = "safetensors-0.4.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:22d21760dc6ebae42e9c058d75aa9907d9f35e38f896e3c69ba0e7b213033856"}, + {file = "safetensors-0.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d22c1a10dff3f64d0d68abb8298a3fd88ccff79f408a3e15b3e7f637ef5c980"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1648568667f820b8c48317c7006221dc40aced1869908c187f493838a1362bc"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:446e9fe52c051aeab12aac63d1017e0f68a02a92a027b901c4f8e931b24e5397"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fef5d70683643618244a4f5221053567ca3e77c2531e42ad48ae05fae909f542"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a1f4430cc0c9d6afa01214a4b3919d0a029637df8e09675ceef1ca3f0dfa0df"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d603846a8585b9432a0fd415db1d4c57c0f860eb4aea21f92559ff9902bae4d"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a844cdb5d7cbc22f5f16c7e2a0271170750763c4db08381b7f696dbd2c78a361"}, + {file = "safetensors-0.4.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:88887f69f7a00cf02b954cdc3034ffb383b2303bc0ab481d4716e2da51ddc10e"}, + {file = "safetensors-0.4.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ee463219d9ec6c2be1d331ab13a8e0cd50d2f32240a81d498266d77d07b7e71e"}, + {file = "safetensors-0.4.3-cp312-none-win32.whl", hash = "sha256:d0dd4a1db09db2dba0f94d15addc7e7cd3a7b0d393aa4c7518c39ae7374623c3"}, + {file = "safetensors-0.4.3-cp312-none-win_amd64.whl", hash = "sha256:d14d30c25897b2bf19b6fb5ff7e26cc40006ad53fd4a88244fdf26517d852dd7"}, + {file = "safetensors-0.4.3-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d1456f814655b224d4bf6e7915c51ce74e389b413be791203092b7ff78c936dd"}, + {file = "safetensors-0.4.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:455d538aa1aae4a8b279344a08136d3f16334247907b18a5c3c7fa88ef0d3c46"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf476bca34e1340ee3294ef13e2c625833f83d096cfdf69a5342475602004f95"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02ef3a24face643456020536591fbd3c717c5abaa2737ec428ccbbc86dffa7a4"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7de32d0d34b6623bb56ca278f90db081f85fb9c5d327e3c18fd23ac64f465768"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a0deb16a1d3ea90c244ceb42d2c6c276059616be21a19ac7101aa97da448faf"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c59d51f182c729f47e841510b70b967b0752039f79f1de23bcdd86462a9b09ee"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1f598b713cc1a4eb31d3b3203557ac308acf21c8f41104cdd74bf640c6e538e3"}, + {file = "safetensors-0.4.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5757e4688f20df083e233b47de43845d1adb7e17b6cf7da5f8444416fc53828d"}, + {file = "safetensors-0.4.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fe746d03ed8d193674a26105e4f0fe6c726f5bb602ffc695b409eaf02f04763d"}, + {file = "safetensors-0.4.3-cp37-none-win32.whl", hash = "sha256:0d5ffc6a80f715c30af253e0e288ad1cd97a3d0086c9c87995e5093ebc075e50"}, + {file = "safetensors-0.4.3-cp37-none-win_amd64.whl", hash = "sha256:a11c374eb63a9c16c5ed146457241182f310902bd2a9c18255781bb832b6748b"}, + {file = "safetensors-0.4.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:b1e31be7945f66be23f4ec1682bb47faa3df34cb89fc68527de6554d3c4258a4"}, + {file = "safetensors-0.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:03a4447c784917c9bf01d8f2ac5080bc15c41692202cd5f406afba16629e84d6"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d244bcafeb1bc06d47cfee71727e775bca88a8efda77a13e7306aae3813fa7e4"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53c4879b9c6bd7cd25d114ee0ef95420e2812e676314300624594940a8d6a91f"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74707624b81f1b7f2b93f5619d4a9f00934d5948005a03f2c1845ffbfff42212"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d52c958dc210265157573f81d34adf54e255bc2b59ded6218500c9b15a750eb"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f9568f380f513a60139971169c4a358b8731509cc19112369902eddb33faa4d"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0d9cd8e1560dfc514b6d7859247dc6a86ad2f83151a62c577428d5102d872721"}, + {file = "safetensors-0.4.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:89f9f17b0dacb913ed87d57afbc8aad85ea42c1085bd5de2f20d83d13e9fc4b2"}, + {file = "safetensors-0.4.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1139eb436fd201c133d03c81209d39ac57e129f5e74e34bb9ab60f8d9b726270"}, + {file = "safetensors-0.4.3-cp38-none-win32.whl", hash = "sha256:d9c289f140a9ae4853fc2236a2ffc9a9f2d5eae0cb673167e0f1b8c18c0961ac"}, + {file = "safetensors-0.4.3-cp38-none-win_amd64.whl", hash = "sha256:622afd28968ef3e9786562d352659a37de4481a4070f4ebac883f98c5836563e"}, + {file = "safetensors-0.4.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8651c7299cbd8b4161a36cd6a322fa07d39cd23535b144d02f1c1972d0c62f3c"}, + {file = "safetensors-0.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e375d975159ac534c7161269de24ddcd490df2157b55c1a6eeace6cbb56903f0"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:084fc436e317f83f7071fc6a62ca1c513b2103db325cd09952914b50f51cf78f"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:41a727a7f5e6ad9f1db6951adee21bbdadc632363d79dc434876369a17de6ad6"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7dbbde64b6c534548696808a0e01276d28ea5773bc9a2dfb97a88cd3dffe3df"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bbae3b4b9d997971431c346edbfe6e41e98424a097860ee872721e176040a893"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01e4b22e3284cd866edeabe4f4d896229495da457229408d2e1e4810c5187121"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dd37306546b58d3043eb044c8103a02792cc024b51d1dd16bd3dd1f334cb3ed"}, + {file = "safetensors-0.4.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d8815b5e1dac85fc534a97fd339e12404db557878c090f90442247e87c8aeaea"}, + {file = "safetensors-0.4.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e011cc162503c19f4b1fd63dfcddf73739c7a243a17dac09b78e57a00983ab35"}, + {file = "safetensors-0.4.3-cp39-none-win32.whl", hash = "sha256:01feb3089e5932d7e662eda77c3ecc389f97c0883c4a12b5cfdc32b589a811c3"}, + {file = "safetensors-0.4.3-cp39-none-win_amd64.whl", hash = "sha256:3f9cdca09052f585e62328c1c2923c70f46814715c795be65f0b93f57ec98a02"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1b89381517891a7bb7d1405d828b2bf5d75528299f8231e9346b8eba092227f9"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:cd6fff9e56df398abc5866b19a32124815b656613c1c5ec0f9350906fd798aac"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:840caf38d86aa7014fe37ade5d0d84e23dcfbc798b8078015831996ecbc206a3"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9650713b2cfa9537a2baf7dd9fee458b24a0aaaa6cafcea8bdd5fb2b8efdc34"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e4119532cd10dba04b423e0f86aecb96cfa5a602238c0aa012f70c3a40c44b50"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e066e8861eef6387b7c772344d1fe1f9a72800e04ee9a54239d460c400c72aab"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:90964917f5b0fa0fa07e9a051fbef100250c04d150b7026ccbf87a34a54012e0"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c41e1893d1206aa7054029681778d9a58b3529d4c807002c156d58426c225173"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae7613a119a71a497d012ccc83775c308b9c1dab454806291427f84397d852fd"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9bac020faba7f5dc481e881b14b6425265feabb5bfc552551d21189c0eddc3"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:420a98f593ff9930f5822560d14c395ccbc57342ddff3b463bc0b3d6b1951550"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f5e6883af9a68c0028f70a4c19d5a6ab6238a379be36ad300a22318316c00cb0"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:cdd0a3b5da66e7f377474599814dbf5cbf135ff059cc73694de129b58a5e8a2c"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9bfb92f82574d9e58401d79c70c716985dc049b635fef6eecbb024c79b2c46ad"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:3615a96dd2dcc30eb66d82bc76cda2565f4f7bfa89fcb0e31ba3cea8a1a9ecbb"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:868ad1b6fc41209ab6bd12f63923e8baeb1a086814cb2e81a65ed3d497e0cf8f"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ffba80aa49bd09195145a7fd233a7781173b422eeb995096f2b30591639517"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0acbe31340ab150423347e5b9cc595867d814244ac14218932a5cf1dd38eb39"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19bbdf95de2cf64f25cd614c5236c8b06eb2cfa47cbf64311f4b5d80224623a3"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b852e47eb08475c2c1bd8131207b405793bfc20d6f45aff893d3baaad449ed14"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5d07cbca5b99babb692d76d8151bec46f461f8ad8daafbfd96b2fca40cadae65"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1ab6527a20586d94291c96e00a668fa03f86189b8a9defa2cdd34a1a01acc7d5"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02318f01e332cc23ffb4f6716e05a492c5f18b1d13e343c49265149396284a44"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec4b52ce9a396260eb9731eb6aea41a7320de22ed73a1042c2230af0212758ce"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:018b691383026a2436a22b648873ed11444a364324e7088b99cd2503dd828400"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:309b10dbcab63269ecbf0e2ca10ce59223bb756ca5d431ce9c9eeabd446569da"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b277482120df46e27a58082df06a15aebda4481e30a1c21eefd0921ae7e03f65"}, + {file = "safetensors-0.4.3.tar.gz", hash = "sha256:2f85fc50c4e07a21e95c24e07460fe6f7e2859d0ce88092838352b798ce711c2"}, ] [package.extras] @@ -2819,51 +2655,6 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] -[[package]] -name = "scikit-learn" -version = "1.5.1" -description = "A set of python modules for machine learning and data mining" -optional = false -python-versions = ">=3.9" -files = [ - {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"}, - {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"}, - {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"}, - {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"}, - {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"}, - {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"}, - {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"}, - {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"}, - {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"}, - {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"}, - {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"}, - {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"}, - {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"}, - {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"}, - {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"}, - {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"}, - {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"}, - {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"}, - {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"}, - {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"}, - {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"}, -] - -[package.dependencies] -joblib = ">=1.2.0" -numpy = ">=1.19.5" -scipy = ">=1.6.0" -threadpoolctl = ">=3.1.0" - -[package.extras] -benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] -build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] -examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] -install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] -maintenance = ["conda-lock (==2.5.6)"] -tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] - [[package]] name = "scipy" version = "1.13.1" @@ -2906,33 +2697,6 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pyde doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] -[[package]] -name = "sentence-transformers" -version = "3.0.1" -description = "Multilingual text embeddings" -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "sentence_transformers-3.0.1-py3-none-any.whl", hash = "sha256:01050cc4053c49b9f5b78f6980b5a72db3fd3a0abb9169b1792ac83875505ee6"}, - {file = "sentence_transformers-3.0.1.tar.gz", hash = "sha256:8a3d2c537cc4d1014ccc20ac92be3d6135420a3bc60ae29a3a8a9b4bb35fbff6"}, -] - -[package.dependencies] -accelerate = {version = ">=0.20.3", optional = true, markers = "extra == \"train\""} -datasets = {version = "*", optional = true, markers = "extra == \"train\""} -huggingface-hub = ">=0.15.1" -numpy = "*" -Pillow = "*" -scikit-learn = "*" -scipy = "*" -torch = ">=1.11.0" -tqdm = "*" -transformers = ">=4.34.0,<5.0.0" - -[package.extras] -dev = ["accelerate (>=0.20.3)", "datasets", "pre-commit", "pytest", "ruff (>=0.3.0)"] -train = ["accelerate (>=0.20.3)", "datasets"] - [[package]] name = "sentencepiece" version = "0.1.99" @@ -2989,25 +2753,24 @@ files = [ [[package]] name = "setuptools" -version = "73.0.1" +version = "70.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-73.0.1-py3-none-any.whl", hash = "sha256:b208925fcb9f7af924ed2dc04708ea89791e24bde0d3020b27df0e116088b34e"}, - {file = "setuptools-73.0.1.tar.gz", hash = "sha256:d59a3e788ab7e012ab2c4baed1b376da6366883ee20d7a5fc426816e3d7b1193"}, + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, ] [package.extras] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, @@ -3016,30 +2779,40 @@ files = [ [[package]] name = "sympy" -version = "1.13.2" +version = "1.12" description = "Computer algebra system (CAS) in Python" -optional = false +optional = true python-versions = ">=3.8" files = [ - {file = "sympy-1.13.2-py3-none-any.whl", hash = "sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9"}, - {file = "sympy-1.13.2.tar.gz", hash = "sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13"}, + {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, + {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, ] [package.dependencies] -mpmath = ">=1.1.0,<1.4" - -[package.extras] -dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +mpmath = ">=0.19" [[package]] -name = "threadpoolctl" -version = "3.5.0" -description = "threadpoolctl" -optional = false -python-versions = ">=3.8" +name = "tbb" +version = "2021.12.0" +description = "Intel® oneAPI Threading Building Blocks (oneTBB)" +optional = true +python-versions = "*" files = [ - {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, - {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, + {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, + {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, +] + +[[package]] +name = "texttable" +version = "1.7.0" +description = "module to create simple ASCII tables" +optional = true +python-versions = "*" +files = [ + {file = "texttable-1.7.0-py2.py3-none-any.whl", hash = "sha256:72227d592c82b3d7f672731ae73e4d1f88cd8e2ef5b075a7a7f01a23a3743917"}, + {file = "texttable-1.7.0.tar.gz", hash = "sha256:2d2068fb55115807d3ac77a4ca68fa48803e84ebb0ee2340f858107a36522638"}, ] [[package]] @@ -3172,43 +2945,44 @@ files = [ [[package]] name = "torch" -version = "2.4.0" +version = "2.3.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -optional = false +optional = true python-versions = ">=3.8.0" files = [ - {file = "torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:4ed94583e244af51d6a8d28701ca5a9e02d1219e782f5a01dd401f90af17d8ac"}, - {file = "torch-2.4.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c4ca297b7bd58b506bfd6e78ffd14eb97c0e7797dcd7965df62f50bb575d8954"}, - {file = "torch-2.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2497cbc7b3c951d69b276ca51fe01c2865db67040ac67f5fc20b03e41d16ea4a"}, - {file = "torch-2.4.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:685418ab93730efbee71528821ff54005596970dd497bf03c89204fb7e3f71de"}, - {file = "torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e743adadd8c8152bb8373543964551a7cb7cc20ba898dc8f9c0cdbe47c283de0"}, - {file = "torch-2.4.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:7334325c0292cbd5c2eac085f449bf57d3690932eac37027e193ba775703c9e6"}, - {file = "torch-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:97730014da4c57ffacb3c09298c6ce05400606e890bd7a05008d13dd086e46b1"}, - {file = "torch-2.4.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f169b4ea6dc93b3a33319611fcc47dc1406e4dd539844dcbd2dec4c1b96e166d"}, - {file = "torch-2.4.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:997084a0f9784d2a89095a6dc67c7925e21bf25dea0b3d069b41195016ccfcbb"}, - {file = "torch-2.4.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc3988e8b36d1e8b998d143255d9408d8c75da4ab6dd0dcfd23b623dfb0f0f57"}, - {file = "torch-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3374128bbf7e62cdaed6c237bfd39809fbcfaa576bee91e904706840c3f2195c"}, - {file = "torch-2.4.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:91aaf00bfe1ffa44dc5b52809d9a95129fca10212eca3ac26420eb11727c6288"}, - {file = "torch-2.4.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cc30457ea5489c62747d3306438af00c606b509d78822a88f804202ba63111ed"}, - {file = "torch-2.4.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a046491aaf96d1215e65e1fa85911ef2ded6d49ea34c8df4d0638879f2402eef"}, - {file = "torch-2.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:688eec9240f3ce775f22e1e1a5ab9894f3d5fe60f3f586deb7dbd23a46a83916"}, - {file = "torch-2.4.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:3af4de2a618fb065e78404c4ba27a818a7b7957eaeff28c6c66ce7fb504b68b8"}, - {file = "torch-2.4.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:618808d3f610d5f180e47a697d4ec90b810953bb1e020f424b2ac7fb0884b545"}, - {file = "torch-2.4.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed765d232d23566052ba83632ec73a4fccde00b4c94ad45d63b471b09d63b7a7"}, - {file = "torch-2.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2feb98ac470109472fb10dfef38622a7ee08482a16c357863ebc7bc7db7c8f7"}, - {file = "torch-2.4.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:8940fc8b97a4c61fdb5d46a368f21f4a3a562a17879e932eb51a5ec62310cb31"}, + {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, + {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, + {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, + {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, + {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, + {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, + {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, + {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, + {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, + {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, + {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, + {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, + {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, + {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, + {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, + {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, + {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, + {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, + {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, + {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, ] [package.dependencies] filelock = "*" fsspec = "*" jinja2 = "*" +mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} @@ -3216,22 +2990,22 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} +triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.11.0)"] +optree = ["optree (>=0.9.1)"] [[package]] name = "tqdm" -version = "4.66.5" +version = "4.66.4" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, - {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, + {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, + {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, ] [package.dependencies] @@ -3245,41 +3019,38 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.43.4" +version = "4.41.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.43.4-py3-none-any.whl", hash = "sha256:d2202ed201e0c44f80de8d753a19f6164187754630bc1f915661b9511d61c773"}, - {file = "transformers-4.43.4.tar.gz", hash = "sha256:b62288990a65ed9bfb79191e04dbb76c9376834ae6e0dd911320a2ced63324fe"}, + {file = "transformers-4.41.1-py3-none-any.whl", hash = "sha256:f0680e0b1a01067eccd11f62f0522409422c7d6f91d532fe0f50b136a406129d"}, + {file = "transformers-4.41.1.tar.gz", hash = "sha256:fa859e4c66f0896633a3bf534e0d9a29a9a88478a49f94c5d8270537dc61cc42"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.23.2,<1.0" +huggingface-hub = ">=0.23.0,<1.0" numpy = ">=1.17" packaging = ">=20.0" -protobuf = {version = "*", optional = true, markers = "extra == \"sentencepiece\""} pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.4.1" -sentencepiece = {version = ">=0.1.91,<0.1.92 || >0.1.92", optional = true, markers = "extra == \"sentencepiece\""} tokenizers = ">=0.19,<0.20" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -benchmark = ["optimum-benchmark (>=0.2.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] @@ -3290,41 +3061,41 @@ natten = ["natten (>=0.14.6,<0.15.0)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.4.4)", "urllib3 (<2.0.0)"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] -ruff = ["ruff (==0.4.4)"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -timm = ["timm (<=0.9.16)"] +timm = ["timm"] tokenizers = ["tokenizers (>=0.19,<0.20)"] torch = ["accelerate (>=0.21.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "3.0.0" +version = "2.3.0" description = "A language and compiler for custom Deep Learning operations" -optional = false +optional = true python-versions = "*" files = [ - {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, - {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, - {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, - {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, - {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, + {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, + {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, + {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, + {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, + {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, ] [package.dependencies] @@ -3332,18 +3103,18 @@ filelock = "*" [package.extras] build = ["cmake (>=3.20)", "lit"] -tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] -tutorials = ["matplotlib", "pandas", "tabulate"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "typer" -version = "0.7.0" +version = "0.6.1" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.6" files = [ - {file = "typer-0.7.0-py3-none-any.whl", hash = "sha256:b5e704f4e48ec263de1c0b3a2387cd405a13767d2f907f44c1a08cbad96f606d"}, - {file = "typer-0.7.0.tar.gz", hash = "sha256:ff797846578a9f2a201b53442aedeb543319466870fbe1c701eab66dd7681165"}, + {file = "typer-0.6.1-py3-none-any.whl", hash = "sha256:54b19e5df18654070a82f8c2aa1da456a4ac16a2a83e6dcd9f170e291c56338e"}, + {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, ] [package.dependencies] @@ -3352,25 +3123,25 @@ click = ">=7.1.1,<9.0.0" [package.extras] all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] -doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] -test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] [[package]] name = "typing-extensions" -version = "4.12.2" +version = "4.12.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, - {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, + {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, + {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, ] [[package]] name = "tzdata" version = "2024.1" description = "Provider of IANA time zone data" -optional = false +optional = true python-versions = ">=2" files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, @@ -3379,13 +3150,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.2" +version = "2.2.1" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, - {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, + {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, + {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, ] [package.extras] @@ -3489,141 +3260,126 @@ files = [ [[package]] name = "xxhash" -version = "3.5.0" +version = "3.4.1" description = "Python binding for xxHash" -optional = false +optional = true python-versions = ">=3.7" files = [ - {file = "xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212"}, - {file = "xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520"}, - {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c5d3e570ef46adaf93fc81b44aca6002b5a4d8ca11bd0580c07eac537f36680"}, - {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb29a034301e2982df8b1fe6328a84f4b676106a13e9135a0d7e0c3e9f806da"}, - {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0d307d27099bb0cbeea7260eb39ed4fdb99c5542e21e94bb6fd29e49c57a23"}, - {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0342aafd421795d740e514bc9858ebddfc705a75a8c5046ac56d85fe97bf196"}, - {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dbbd9892c5ebffeca1ed620cf0ade13eb55a0d8c84e0751a6653adc6ac40d0c"}, - {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4cc2d67fdb4d057730c75a64c5923abfa17775ae234a71b0200346bfb0a7f482"}, - {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ec28adb204b759306a3d64358a5e5c07d7b1dd0ccbce04aa76cb9377b7b70296"}, - {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1328f6d8cca2b86acb14104e381225a3d7b42c92c4b86ceae814e5c400dbb415"}, - {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8d47ebd9f5d9607fd039c1fbf4994e3b071ea23eff42f4ecef246ab2b7334198"}, - {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b96d559e0fcddd3343c510a0fe2b127fbff16bf346dd76280b82292567523442"}, - {file = "xxhash-3.5.0-cp310-cp310-win32.whl", hash = "sha256:61c722ed8d49ac9bc26c7071eeaa1f6ff24053d553146d5df031802deffd03da"}, - {file = "xxhash-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9bed5144c6923cc902cd14bb8963f2d5e034def4486ab0bbe1f58f03f042f9a9"}, - {file = "xxhash-3.5.0-cp310-cp310-win_arm64.whl", hash = "sha256:893074d651cf25c1cc14e3bea4fceefd67f2921b1bb8e40fcfeba56820de80c6"}, - {file = "xxhash-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02c2e816896dc6f85922ced60097bcf6f008dedfc5073dcba32f9c8dd786f3c1"}, - {file = "xxhash-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6027dcd885e21581e46d3c7f682cfb2b870942feeed58a21c29583512c3f09f8"}, - {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1308fa542bbdbf2fa85e9e66b1077eea3a88bef38ee8a06270b4298a7a62a166"}, - {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c28b2fdcee797e1c1961cd3bcd3d545cab22ad202c846235197935e1df2f8ef7"}, - {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:924361811732ddad75ff23e90efd9ccfda4f664132feecb90895bade6a1b4623"}, - {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89997aa1c4b6a5b1e5b588979d1da048a3c6f15e55c11d117a56b75c84531f5a"}, - {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:685c4f4e8c59837de103344eb1c8a3851f670309eb5c361f746805c5471b8c88"}, - {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbd2ecfbfee70bc1a4acb7461fa6af7748ec2ab08ac0fa298f281c51518f982c"}, - {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:25b5a51dc3dfb20a10833c8eee25903fd2e14059e9afcd329c9da20609a307b2"}, - {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a8fb786fb754ef6ff8c120cb96629fb518f8eb5a61a16aac3a979a9dbd40a084"}, - {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a905ad00ad1e1c34fe4e9d7c1d949ab09c6fa90c919860c1534ff479f40fd12d"}, - {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:963be41bcd49f53af6d795f65c0da9b4cc518c0dd9c47145c98f61cb464f4839"}, - {file = "xxhash-3.5.0-cp311-cp311-win32.whl", hash = "sha256:109b436096d0a2dd039c355fa3414160ec4d843dfecc64a14077332a00aeb7da"}, - {file = "xxhash-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b702f806693201ad6c0a05ddbbe4c8f359626d0b3305f766077d51388a6bac58"}, - {file = "xxhash-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:c4dcb4120d0cc3cc448624147dba64e9021b278c63e34a38789b688fd0da9bf3"}, - {file = "xxhash-3.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:14470ace8bd3b5d51318782cd94e6f94431974f16cb3b8dc15d52f3b69df8e00"}, - {file = "xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59aa1203de1cb96dbeab595ded0ad0c0056bb2245ae11fac11c0ceea861382b9"}, - {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08424f6648526076e28fae6ea2806c0a7d504b9ef05ae61d196d571e5c879c84"}, - {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61a1ff00674879725b194695e17f23d3248998b843eb5e933007ca743310f793"}, - {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2f2c61bee5844d41c3eb015ac652a0229e901074951ae48581d58bfb2ba01be"}, - {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d32a592cac88d18cc09a89172e1c32d7f2a6e516c3dfde1b9adb90ab5df54a6"}, - {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70dabf941dede727cca579e8c205e61121afc9b28516752fd65724be1355cc90"}, - {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e5d0ddaca65ecca9c10dcf01730165fd858533d0be84c75c327487c37a906a27"}, - {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e5b5e16c5a480fe5f59f56c30abdeba09ffd75da8d13f6b9b6fd224d0b4d0a2"}, - {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149b7914451eb154b3dfaa721315117ea1dac2cc55a01bfbd4df7c68c5dd683d"}, - {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:eade977f5c96c677035ff39c56ac74d851b1cca7d607ab3d8f23c6b859379cab"}, - {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa9f547bd98f5553d03160967866a71056a60960be00356a15ecc44efb40ba8e"}, - {file = "xxhash-3.5.0-cp312-cp312-win32.whl", hash = "sha256:f7b58d1fd3551b8c80a971199543379be1cee3d0d409e1f6d8b01c1a2eebf1f8"}, - {file = "xxhash-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0cafd3a2af231b4e113fba24a65d7922af91aeb23774a8b78228e6cd785e3e"}, - {file = "xxhash-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:586886c7e89cb9828bcd8a5686b12e161368e0064d040e225e72607b43858ba2"}, - {file = "xxhash-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37889a0d13b0b7d739cfc128b1c902f04e32de17b33d74b637ad42f1c55101f6"}, - {file = "xxhash-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97a662338797c660178e682f3bc180277b9569a59abfb5925e8620fba00b9fc5"}, - {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f85e0108d51092bdda90672476c7d909c04ada6923c14ff9d913c4f7dc8a3bc"}, - {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2fd827b0ba763ac919440042302315c564fdb797294d86e8cdd4578e3bc7f3"}, - {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82085c2abec437abebf457c1d12fccb30cc8b3774a0814872511f0f0562c768c"}, - {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07fda5de378626e502b42b311b049848c2ef38784d0d67b6f30bb5008642f8eb"}, - {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c279f0d2b34ef15f922b77966640ade58b4ccdfef1c4d94b20f2a364617a493f"}, - {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89e66ceed67b213dec5a773e2f7a9e8c58f64daeb38c7859d8815d2c89f39ad7"}, - {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bcd51708a633410737111e998ceb3b45d3dbc98c0931f743d9bb0a209033a326"}, - {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3ff2c0a34eae7df88c868be53a8dd56fbdf592109e21d4bfa092a27b0bf4a7bf"}, - {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4e28503dccc7d32e0b9817aa0cbfc1f45f563b2c995b7a66c4c8a0d232e840c7"}, - {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6c50017518329ed65a9e4829154626f008916d36295b6a3ba336e2458824c8c"}, - {file = "xxhash-3.5.0-cp313-cp313-win32.whl", hash = "sha256:53a068fe70301ec30d868ece566ac90d873e3bb059cf83c32e76012c889b8637"}, - {file = "xxhash-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:80babcc30e7a1a484eab952d76a4f4673ff601f54d5142c26826502740e70b43"}, - {file = "xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b"}, - {file = "xxhash-3.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6e5f70f6dca1d3b09bccb7daf4e087075ff776e3da9ac870f86ca316736bb4aa"}, - {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e76e83efc7b443052dd1e585a76201e40b3411fe3da7af4fe434ec51b2f163b"}, - {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33eac61d0796ca0591f94548dcfe37bb193671e0c9bcf065789b5792f2eda644"}, - {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ec70a89be933ea49222fafc3999987d7899fc676f688dd12252509434636622"}, - {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86b8e7f703ec6ff4f351cfdb9f428955859537125904aa8c963604f2e9d3e7"}, - {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0adfbd36003d9f86c8c97110039f7539b379f28656a04097e7434d3eaf9aa131"}, - {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:63107013578c8a730419adc05608756c3fa640bdc6abe806c3123a49fb829f43"}, - {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:683b94dbd1ca67557850b86423318a2e323511648f9f3f7b1840408a02b9a48c"}, - {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5d2a01dcce81789cf4b12d478b5464632204f4c834dc2d064902ee27d2d1f0ee"}, - {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:a9d360a792cbcce2fe7b66b8d51274ec297c53cbc423401480e53b26161a290d"}, - {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:f0b48edbebea1b7421a9c687c304f7b44d0677c46498a046079d445454504737"}, - {file = "xxhash-3.5.0-cp37-cp37m-win32.whl", hash = "sha256:7ccb800c9418e438b44b060a32adeb8393764da7441eb52aa2aa195448935306"}, - {file = "xxhash-3.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c3bc7bf8cb8806f8d1c9bf149c18708cb1c406520097d6b0a73977460ea03602"}, - {file = "xxhash-3.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:74752ecaa544657d88b1d1c94ae68031e364a4d47005a90288f3bab3da3c970f"}, - {file = "xxhash-3.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dee1316133c9b463aa81aca676bc506d3f80d8f65aeb0bba2b78d0b30c51d7bd"}, - {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:602d339548d35a8579c6b013339fb34aee2df9b4e105f985443d2860e4d7ffaa"}, - {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:695735deeddfb35da1677dbc16a083445360e37ff46d8ac5c6fcd64917ff9ade"}, - {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1030a39ba01b0c519b1a82f80e8802630d16ab95dc3f2b2386a0b5c8ed5cbb10"}, - {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5bc08f33c4966f4eb6590d6ff3ceae76151ad744576b5fc6c4ba8edd459fdec"}, - {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160e0c19ee500482ddfb5d5570a0415f565d8ae2b3fd69c5dcfce8a58107b1c3"}, - {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f1abffa122452481a61c3551ab3c89d72238e279e517705b8b03847b1d93d738"}, - {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:d5e9db7ef3ecbfc0b4733579cea45713a76852b002cf605420b12ef3ef1ec148"}, - {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:23241ff6423378a731d84864bf923a41649dc67b144debd1077f02e6249a0d54"}, - {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:82b833d5563fefd6fceafb1aed2f3f3ebe19f84760fdd289f8b926731c2e6e91"}, - {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0a80ad0ffd78bef9509eee27b4a29e56f5414b87fb01a888353e3d5bda7038bd"}, - {file = "xxhash-3.5.0-cp38-cp38-win32.whl", hash = "sha256:50ac2184ffb1b999e11e27c7e3e70cc1139047e7ebc1aa95ed12f4269abe98d4"}, - {file = "xxhash-3.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:392f52ebbb932db566973693de48f15ce787cabd15cf6334e855ed22ea0be5b3"}, - {file = "xxhash-3.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc8cdd7f33d57f0468b0614ae634cc38ab9202c6957a60e31d285a71ebe0301"}, - {file = "xxhash-3.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0c48b6300cd0b0106bf49169c3e0536408dfbeb1ccb53180068a18b03c662ab"}, - {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe1a92cfbaa0a1253e339ccec42dbe6db262615e52df591b68726ab10338003f"}, - {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33513d6cc3ed3b559134fb307aae9bdd94d7e7c02907b37896a6c45ff9ce51bd"}, - {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eefc37f6138f522e771ac6db71a6d4838ec7933939676f3753eafd7d3f4c40bc"}, - {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a606c8070ada8aa2a88e181773fa1ef17ba65ce5dd168b9d08038e2a61b33754"}, - {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42eca420c8fa072cc1dd62597635d140e78e384a79bb4944f825fbef8bfeeef6"}, - {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:604253b2143e13218ff1ef0b59ce67f18b8bd1c4205d2ffda22b09b426386898"}, - {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6e93a5ad22f434d7876665444a97e713a8f60b5b1a3521e8df11b98309bff833"}, - {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:7a46e1d6d2817ba8024de44c4fd79913a90e5f7265434cef97026215b7d30df6"}, - {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:30eb2efe6503c379b7ab99c81ba4a779748e3830241f032ab46bd182bf5873af"}, - {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c8aa771ff2c13dd9cda8166d685d7333d389fae30a4d2bb39d63ab5775de8606"}, - {file = "xxhash-3.5.0-cp39-cp39-win32.whl", hash = "sha256:5ed9ebc46f24cf91034544b26b131241b699edbfc99ec5e7f8f3d02d6eb7fba4"}, - {file = "xxhash-3.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:220f3f896c6b8d0316f63f16c077d52c412619e475f9372333474ee15133a558"}, - {file = "xxhash-3.5.0-cp39-cp39-win_arm64.whl", hash = "sha256:a7b1d8315d9b5e9f89eb2933b73afae6ec9597a258d52190944437158b49d38e"}, - {file = "xxhash-3.5.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2014c5b3ff15e64feecb6b713af12093f75b7926049e26a580e94dcad3c73d8c"}, - {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab81ef75003eda96239a23eda4e4543cedc22e34c373edcaf744e721a163986"}, - {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e2febf914ace002132aa09169cc572e0d8959d0f305f93d5828c4836f9bc5a6"}, - {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d3a10609c51da2a1c0ea0293fc3968ca0a18bd73838455b5bca3069d7f8e32b"}, - {file = "xxhash-3.5.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a74f23335b9689b66eb6dbe2a931a88fcd7a4c2cc4b1cb0edba8ce381c7a1da"}, - {file = "xxhash-3.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2b4154c00eb22e4d543f472cfca430e7962a0f1d0f3778334f2e08a7ba59363c"}, - {file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d30bbc1644f726b825b3278764240f449d75f1a8bdda892e641d4a688b1494ae"}, - {file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa0b72f2423e2aa53077e54a61c28e181d23effeaafd73fcb9c494e60930c8e"}, - {file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:13de2b76c1835399b2e419a296d5b38dc4855385d9e96916299170085ef72f57"}, - {file = "xxhash-3.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:0691bfcc4f9c656bcb96cc5db94b4d75980b9d5589f2e59de790091028580837"}, - {file = "xxhash-3.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:297595fe6138d4da2c8ce9e72a04d73e58725bb60f3a19048bc96ab2ff31c692"}, - {file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc1276d369452040cbb943300dc8abeedab14245ea44056a2943183822513a18"}, - {file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2061188a1ba352fc699c82bff722f4baacb4b4b8b2f0c745d2001e56d0dfb514"}, - {file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38c384c434021e4f62b8d9ba0bc9467e14d394893077e2c66d826243025e1f81"}, - {file = "xxhash-3.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e6a4dd644d72ab316b580a1c120b375890e4c52ec392d4aef3c63361ec4d77d1"}, - {file = "xxhash-3.5.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:531af8845aaadcadf951b7e0c1345c6b9c68a990eeb74ff9acd8501a0ad6a1c9"}, - {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ce379bcaa9fcc00f19affa7773084dd09f5b59947b3fb47a1ceb0179f91aaa1"}, - {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd1b2281d01723f076df3c8188f43f2472248a6b63118b036e641243656b1b0f"}, - {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c770750cc80e8694492244bca7251385188bc5597b6a39d98a9f30e8da984e0"}, - {file = "xxhash-3.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b150b8467852e1bd844387459aa6fbe11d7f38b56e901f9f3b3e6aba0d660240"}, - {file = "xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f"}, + {file = "xxhash-3.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91dbfa55346ad3e18e738742236554531a621042e419b70ad8f3c1d9c7a16e7f"}, + {file = "xxhash-3.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:665a65c2a48a72068fcc4d21721510df5f51f1142541c890491afc80451636d2"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb11628470a6004dc71a09fe90c2f459ff03d611376c1debeec2d648f44cb693"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bef2a7dc7b4f4beb45a1edbba9b9194c60a43a89598a87f1a0226d183764189"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0f7b2d547d72c7eda7aa817acf8791f0146b12b9eba1d4432c531fb0352228"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00f2fdef6b41c9db3d2fc0e7f94cb3db86693e5c45d6de09625caad9a469635b"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23cfd9ca09acaf07a43e5a695143d9a21bf00f5b49b15c07d5388cadf1f9ce11"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6a9ff50a3cf88355ca4731682c168049af1ca222d1d2925ef7119c1a78e95b3b"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f1d7c69a1e9ca5faa75546fdd267f214f63f52f12692f9b3a2f6467c9e67d5e7"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:672b273040d5d5a6864a36287f3514efcd1d4b1b6a7480f294c4b1d1ee1b8de0"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4178f78d70e88f1c4a89ff1ffe9f43147185930bb962ee3979dba15f2b1cc799"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9804b9eb254d4b8cc83ab5a2002128f7d631dd427aa873c8727dba7f1f0d1c2b"}, + {file = "xxhash-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c09c49473212d9c87261d22c74370457cfff5db2ddfc7fd1e35c80c31a8c14ce"}, + {file = "xxhash-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:ebbb1616435b4a194ce3466d7247df23499475c7ed4eb2681a1fa42ff766aff6"}, + {file = "xxhash-3.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:25dc66be3db54f8a2d136f695b00cfe88018e59ccff0f3b8f545869f376a8a46"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58c49083801885273e262c0f5bbeac23e520564b8357fbb18fb94ff09d3d3ea5"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b526015a973bfbe81e804a586b703f163861da36d186627e27524f5427b0d520"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ad4457644c91a966f6fe137d7467636bdc51a6ce10a1d04f365c70d6a16d7e"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:248d3e83d119770f96003271fe41e049dd4ae52da2feb8f832b7a20e791d2920"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2070b6d5bbef5ee031666cf21d4953c16e92c2f8a24a94b5c240f8995ba3b1d0"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2746035f518f0410915e247877f7df43ef3372bf36cfa52cc4bc33e85242641"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ba6181514681c2591840d5632fcf7356ab287d4aff1c8dea20f3c78097088"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aac5010869240e95f740de43cd6a05eae180c59edd182ad93bf12ee289484fa"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4cb11d8debab1626181633d184b2372aaa09825bde709bf927704ed72765bed1"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b29728cff2c12f3d9f1d940528ee83918d803c0567866e062683f300d1d2eff3"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a15cbf3a9c40672523bdb6ea97ff74b443406ba0ab9bca10ceccd9546414bd84"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e66df260fed01ed8ea790c2913271641c58481e807790d9fca8bfd5a3c13844"}, + {file = "xxhash-3.4.1-cp311-cp311-win32.whl", hash = "sha256:e867f68a8f381ea12858e6d67378c05359d3a53a888913b5f7d35fbf68939d5f"}, + {file = "xxhash-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:200a5a3ad9c7c0c02ed1484a1d838b63edcf92ff538770ea07456a3732c577f4"}, + {file = "xxhash-3.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:1d03f1c0d16d24ea032e99f61c552cb2b77d502e545187338bea461fde253583"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4bbba9b182697a52bc0c9f8ec0ba1acb914b4937cd4a877ad78a3b3eeabefb3"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9fd28a9da300e64e434cfc96567a8387d9a96e824a9be1452a1e7248b7763b78"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6066d88c9329ab230e18998daec53d819daeee99d003955c8db6fc4971b45ca3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93805bc3233ad89abf51772f2ed3355097a5dc74e6080de19706fc447da99cd3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64da57d5ed586ebb2ecdde1e997fa37c27fe32fe61a656b77fabbc58e6fbff6e"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a97322e9a7440bf3c9805cbaac090358b43f650516486746f7fa482672593df"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe750d512982ee7d831838a5dee9e9848f3fb440e4734cca3f298228cc957a6"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fd79d4087727daf4d5b8afe594b37d611ab95dc8e29fe1a7517320794837eb7d"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:743612da4071ff9aa4d055f3f111ae5247342931dedb955268954ef7201a71ff"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b41edaf05734092f24f48c0958b3c6cbaaa5b7e024880692078c6b1f8247e2fc"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:a90356ead70d715fe64c30cd0969072de1860e56b78adf7c69d954b43e29d9fa"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac56eebb364e44c85e1d9e9cc5f6031d78a34f0092fea7fc80478139369a8b4a"}, + {file = "xxhash-3.4.1-cp312-cp312-win32.whl", hash = "sha256:911035345932a153c427107397c1518f8ce456f93c618dd1c5b54ebb22e73747"}, + {file = "xxhash-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:f31ce76489f8601cc7b8713201ce94b4bd7b7ce90ba3353dccce7e9e1fee71fa"}, + {file = "xxhash-3.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:b5beb1c6a72fdc7584102f42c4d9df232ee018ddf806e8c90906547dfb43b2da"}, + {file = "xxhash-3.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d42b24d1496deb05dee5a24ed510b16de1d6c866c626c2beb11aebf3be278b9"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b685fab18876b14a8f94813fa2ca80cfb5ab6a85d31d5539b7cd749ce9e3624"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419ffe34c17ae2df019a4685e8d3934d46b2e0bbe46221ab40b7e04ed9f11137"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e041ce5714f95251a88670c114b748bca3bf80cc72400e9f23e6d0d59cf2681"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc860d887c5cb2f524899fb8338e1bb3d5789f75fac179101920d9afddef284b"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:312eba88ffe0a05e332e3a6f9788b73883752be63f8588a6dc1261a3eaaaf2b2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e01226b6b6a1ffe4e6bd6d08cfcb3ca708b16f02eb06dd44f3c6e53285f03e4f"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9f3025a0d5d8cf406a9313cd0d5789c77433ba2004b1c75439b67678e5136537"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:6d3472fd4afef2a567d5f14411d94060099901cd8ce9788b22b8c6f13c606a93"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:43984c0a92f06cac434ad181f329a1445017c33807b7ae4f033878d860a4b0f2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a55e0506fdb09640a82ec4f44171273eeabf6f371a4ec605633adb2837b5d9d5"}, + {file = "xxhash-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:faec30437919555b039a8bdbaba49c013043e8f76c999670aef146d33e05b3a0"}, + {file = "xxhash-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:c9e1b646af61f1fc7083bb7b40536be944f1ac67ef5e360bca2d73430186971a"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:961d948b7b1c1b6c08484bbce3d489cdf153e4122c3dfb07c2039621243d8795"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:719a378930504ab159f7b8e20fa2aa1896cde050011af838af7e7e3518dd82de"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74fb5cb9406ccd7c4dd917f16630d2e5e8cbbb02fc2fca4e559b2a47a64f4940"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dab508ac39e0ab988039bc7f962c6ad021acd81fd29145962b068df4148c476"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c59f3e46e7daf4c589e8e853d700ef6607afa037bfad32c390175da28127e8c"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc07256eff0795e0f642df74ad096f8c5d23fe66bc138b83970b50fc7f7f6c5"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9f749999ed80f3955a4af0eb18bb43993f04939350b07b8dd2f44edc98ffee9"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7688d7c02149a90a3d46d55b341ab7ad1b4a3f767be2357e211b4e893efbaaf6"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a8b4977963926f60b0d4f830941c864bed16aa151206c01ad5c531636da5708e"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8106d88da330f6535a58a8195aa463ef5281a9aa23b04af1848ff715c4398fb4"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4c76a77dbd169450b61c06fd2d5d436189fc8ab7c1571d39265d4822da16df22"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:11f11357c86d83e53719c592021fd524efa9cf024dc7cb1dfb57bbbd0d8713f2"}, + {file = "xxhash-3.4.1-cp38-cp38-win32.whl", hash = "sha256:0c786a6cd74e8765c6809892a0d45886e7c3dc54de4985b4a5eb8b630f3b8e3b"}, + {file = "xxhash-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:aabf37fb8fa27430d50507deeab2ee7b1bcce89910dd10657c38e71fee835594"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6127813abc1477f3a83529b6bbcfeddc23162cece76fa69aee8f6a8a97720562"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef2e194262f5db16075caea7b3f7f49392242c688412f386d3c7b07c7733a70a"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71be94265b6c6590f0018bbf73759d21a41c6bda20409782d8117e76cd0dfa8b"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10e0a619cdd1c0980e25eb04e30fe96cf8f4324758fa497080af9c21a6de573f"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa122124d2e3bd36581dd78c0efa5f429f5220313479fb1072858188bc2d5ff1"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17032f5a4fea0a074717fe33477cb5ee723a5f428de7563e75af64bfc1b1e10"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca7783b20e3e4f3f52f093538895863f21d18598f9a48211ad757680c3bd006f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d77d09a1113899fad5f354a1eb4f0a9afcf58cefff51082c8ad643ff890e30cf"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:21287bcdd299fdc3328cc0fbbdeaa46838a1c05391264e51ddb38a3f5b09611f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:dfd7a6cc483e20b4ad90224aeb589e64ec0f31e5610ab9957ff4314270b2bf31"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:543c7fcbc02bbb4840ea9915134e14dc3dc15cbd5a30873a7a5bf66039db97ec"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fe0a98d990e433013f41827b62be9ab43e3cf18e08b1483fcc343bda0d691182"}, + {file = "xxhash-3.4.1-cp39-cp39-win32.whl", hash = "sha256:b9097af00ebf429cc7c0e7d2fdf28384e4e2e91008130ccda8d5ae653db71e54"}, + {file = "xxhash-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:d699b921af0dcde50ab18be76c0d832f803034d80470703700cb7df0fbec2832"}, + {file = "xxhash-3.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:2be491723405e15cc099ade1280133ccfbf6322d2ef568494fb7d07d280e7eee"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:431625fad7ab5649368c4849d2b49a83dc711b1f20e1f7f04955aab86cd307bc"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc6dbd5fc3c9886a9e041848508b7fb65fd82f94cc793253990f81617b61fe49"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ff8dbd0ec97aec842476cb8ccc3e17dd288cd6ce3c8ef38bff83d6eb927817"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef73a53fe90558a4096e3256752268a8bdc0322f4692ed928b6cd7ce06ad4fe3"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:450401f42bbd274b519d3d8dcf3c57166913381a3d2664d6609004685039f9d3"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a162840cf4de8a7cd8720ff3b4417fbc10001eefdd2d21541a8226bb5556e3bb"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b736a2a2728ba45017cb67785e03125a79d246462dfa892d023b827007412c52"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0ae4c2e7698adef58710d6e7a32ff518b66b98854b1c68e70eee504ad061d8"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6322c4291c3ff174dcd104fae41500e75dad12be6f3085d119c2c8a80956c51"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:dd59ed668801c3fae282f8f4edadf6dc7784db6d18139b584b6d9677ddde1b6b"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92693c487e39523a80474b0394645b393f0ae781d8db3474ccdcead0559ccf45"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4603a0f642a1e8d7f3ba5c4c25509aca6a9c1cc16f85091004a7028607ead663"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa45e8cbfbadb40a920fe9ca40c34b393e0b067082d94006f7f64e70c7490a6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:595b252943b3552de491ff51e5bb79660f84f033977f88f6ca1605846637b7c6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:562d8b8f783c6af969806aaacf95b6c7b776929ae26c0cd941d54644ea7ef51e"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:41ddeae47cf2828335d8d991f2d2b03b0bdc89289dc64349d712ff8ce59d0647"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c44d584afdf3c4dbb3277e32321d1a7b01d6071c1992524b6543025fb8f4206f"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd7bddb3a5b86213cc3f2c61500c16945a1b80ecd572f3078ddbbe68f9dabdfb"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ecb6c987b62437c2f99c01e97caf8d25660bf541fe79a481d05732e5236719c"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:696b4e18b7023527d5c50ed0626ac0520edac45a50ec7cf3fc265cd08b1f4c03"}, + {file = "xxhash-3.4.1.tar.gz", hash = "sha256:0379d6cf1ff987cd421609a264ce025e74f346e3e145dd106c0cc2e3ec3f99a9"}, ] [[package]] name = "yarl" version = "1.9.4" description = "Yet another URL library" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, @@ -3722,22 +3478,15 @@ files = [ idna = ">=2.0" multidict = ">=4.0" -[[package]] -name = "zipp" -version = "3.20.0" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.8" -files = [ - {file = "zipp-3.20.0-py3-none-any.whl", hash = "sha256:58da6168be89f0be59beb194da1250516fdaa062ccebd30127ac65d30045e10d"}, - {file = "zipp-3.20.0.tar.gz", hash = "sha256:0145e43d89664cfe1a2e533adc75adafed82fe2da404b4bbb6b026c0157bdb31"}, -] - -[package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +[extras] +accelerate = ["accelerate"] +bnb = ["bitsandbytes"] +outlines = ["outlines"] +peft = ["peft"] +quantize = ["accelerate", "datasets", "texttable"] +torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "6d54ae94305899e548883c4700e93babc406ec896c114b45978111c514531906" +content-hash = "06e67944a2b1cf9884a31e771d0e9d89877e9b3c91894982cb67d104cb834758" diff --git a/server/pyproject.toml b/server/pyproject.toml index c1809550..6a3a0205 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -9,34 +9,58 @@ text-generation-server = 'text_generation_server.cli:app' [tool.poetry.dependencies] python = ">=3.9,<3.13" -protobuf = "^3.20.3" +protobuf = "^4.21.7" grpcio = "^1.51.1" -grpcio-status = "*" -grpcio-reflection = "*" +grpcio-status = "^1.51.1" +grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" -typer = "^0.7.0" +typer = "^0.6.1" +accelerate = { version = "^0.29.1", optional = true } +bitsandbytes = { version = "^0.43.0", optional = true } +safetensors = "^0.4" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" -peft = "^0.10" -optimum-habana = "1.13.2" -transformers = "4.43.4" -numpy = "1.26.4" -accelerate = "0.33.0" +tokenizers = "^0.19.1" +huggingface-hub = "^0.23" +transformers = "^4.41" +einops = "^0.6.1" +texttable = { version = "^1.6.7", optional = true } +datasets = { version = "^2.14.0", optional = true } +peft = { version = "^0.10", optional = true } +torch = { version = "^2.3.0", optional = true } +scipy = "^1.11.1" +pillow = "^10.0.0" outlines= { version = "^0.0.36", optional = true } prometheus-client = "^0.20.0" py-cpuinfo = "^9.0.0" +[tool.poetry.extras] +torch = ["torch"] +accelerate = ["accelerate"] +bnb = ["bitsandbytes"] +peft = ["peft"] +quantize = ["texttable", "datasets", "accelerate"] +outlines = ["outlines"] + [tool.poetry.group.dev.dependencies] -grpcio-tools = "*" +grpcio-tools = "^1.51.1" pytest = "^7.3.0" + +[[tool.poetry.source]] +name = "pytorch-gpu-src" +url = "https://download.pytorch.org/whl/cu121" +priority = "explicit" + [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] [build-system] -requires = ["poetry-core>=1.0.0"] +requires = [ + "poetry-core>=1.0.0", +] build-backend = "poetry.core.masonry.api" diff --git a/server/requirements.txt b/server/requirements.txt deleted file mode 100644 index 0a091a2f..00000000 --- a/server/requirements.txt +++ /dev/null @@ -1,88 +0,0 @@ -accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" -aiohappyeyeballs==2.4.0 ; python_version >= "3.9" and python_version < "3.13" -aiohttp==3.10.5 ; python_version >= "3.9" and python_version < "3.13" -aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" -async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11" -attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13" -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" -charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" -click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") -coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13" -datasets==2.21.0 ; python_version >= "3.9" and python_version < "3.13" -deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -diffusers==0.29.2 ; python_version >= "3.9" and python_version < "3.13" -dill==0.3.8 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" -frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" -fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" -grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" -grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.66.0 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.24.6 ; python_version >= "3.9" and python_version < "3.13" -humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13" -idna==3.8 ; python_version >= "3.9" and python_version < "3.13" -importlib-metadata==8.4.0 ; python_version >= "3.9" and python_version < "3.13" -jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13" -joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13" -loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" -multidict==6.0.5 ; python_version >= "3.9" and python_version < "3.13" -multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "3.13" -networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -optimum-habana==1.13.2 ; python_version >= "3.9" and python_version < "3.13" -optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" -pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13" -peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" -prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" -protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13" -psutil==6.0.0 ; python_version >= "3.9" and python_version < "3.13" -py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" -pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13" -pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13" -python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13" -pytz==2024.1 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" -regex==2024.7.24 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" -safetensors==0.4.4 ; python_version >= "3.9" and python_version < "3.13" -scikit-learn==1.5.1 ; python_version >= "3.9" and python_version < "3.13" -scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" -sentence-transformers[train]==3.0.1 ; python_version >= "3.9" and python_version < "3.13" -sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==73.0.1 ; python_version >= "3.9" and python_version < "3.13" -six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" -sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13" -threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13" -tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13" -transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13" -triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9" -typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" -tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" -win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" -wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" -xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13" -yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13" -zipp==3.20.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 4b7dde81..66df708a 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -1,5 +1,3 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import pytest import torch @@ -7,9 +5,9 @@ from copy import copy from transformers import AutoTokenizer from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import CausalLMBatch, PAD_SEQUENCE_TO_MULTIPLE_OF +from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.utils import weight_hub_files, download_weights -from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM +from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded @pytest.fixture(scope="session") @@ -18,7 +16,7 @@ def default_bloom(): revision = "main" filenames = weight_hub_files(model_id, revision, ".safetensors") download_weights(filenames, model_id, revision) - return BLOOM(model_id) + return BLOOMSharded(model_id) @pytest.fixture(scope="session") @@ -32,7 +30,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): id=0, inputs="Test", prefill_logprobs=True, - truncate=PAD_SEQUENCE_TO_MULTIPLE_OF, + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, ) @@ -46,7 +44,7 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer): return BloomCausalLMBatch.from_pb( - default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("hpu") + default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu") ) @@ -60,7 +58,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer) batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return BloomCausalLMBatch.from_pb( - batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("hpu") + batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu") ) @@ -68,29 +66,30 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch): batch = default_bloom_batch assert batch.batch_id == default_pb_batch.id - assert len(batch.requests) == len(default_pb_batch.requests) == default_pb_batch.size - for request, pb_request in zip(batch.requests, default_pb_batch.requests): - assert request.data == pb_request + assert batch.requests == default_pb_batch.requests - assert batch.input_ids[0][-1] == 3 - assert batch.input_ids[0][-2] == 10264 - assert torch.all(batch.input_ids[0][:-2] == 3) + assert len(batch.input_ids) == default_pb_batch.size + assert batch.input_ids[0][-1] == 10264 + assert torch.all(batch.input_ids[0][:-1] == 3) - assert batch.attention_mask[0][-1] == 0 - assert batch.attention_mask[0][-2] == 1 - assert torch.all(batch.attention_mask[0][:-2] == 0) + assert batch.attention_mask[0][0] == 1 + assert torch.all(batch.attention_mask[0][1:] == 0) assert batch.past_key_values is None assert all( [ - torch.equal(input_ids, request.all_input_ids[:batch.input_length+1, 0]) - for input_ids, request in zip(batch.input_ids, batch.requests) + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) ] ) + assert batch.input_lengths == [1] + assert len(batch) == default_pb_batch.size - assert batch.max_input_length == batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF - 1 + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + + assert batch.max_input_length == batch.input_lengths[0] def test_batch_concatenate_no_prefill(default_bloom_batch): @@ -102,7 +101,6 @@ def test_causal_lm_batch_type(default_bloom): assert default_bloom.batch_type == BloomCausalLMBatch -@pytest.mark.skip def test_causal_lm_generate_token(default_bloom, default_bloom_batch): sequence_length = len(default_bloom_batch.all_input_ids[0]) generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch) @@ -152,7 +150,6 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert generations[0].request_id == 0 -@pytest.mark.skip def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): next_batch = default_bloom_batch for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1): @@ -173,7 +170,6 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch) ) -@pytest.mark.skip def test_causal_lm_generate_token_completion_multi( default_bloom, default_multi_requests_bloom_batch ): @@ -224,7 +220,6 @@ def test_causal_lm_generate_token_completion_multi( ) -@pytest.mark.skip def test_batch_concatenate( default_bloom, default_bloom_batch, default_multi_requests_bloom_batch ): diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6a0a474a..250fa354 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -1,5 +1,3 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import pytest import torch @@ -7,27 +5,19 @@ from copy import copy from transformers import AutoTokenizer from text_generation_server.pb import generate_pb2 -from text_generation_server.models import get_model -from text_generation_server.models.causal_lm import ( - CausalLMBatch, - PREFILL_BATCH_BUCKET_SIZE, - PAD_SEQUENCE_TO_MULTIPLE_OF, - MAX_TOTAL_TOKENS, - BATCH_BUCKET_SIZE, -) - -PAD_TOKEN=0 +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch @pytest.fixture(scope="session") def default_causal_lm(): - return get_model("meta-llama/Llama-2-7b-hf", None, None, None, None) + return CausalLM("gpt2") @pytest.fixture(scope="session") -def default_tokenizer(default_causal_lm): - default_causal_lm.tokenizer.pad_token_id = PAD_TOKEN - return default_causal_lm.tokenizer +def gpt2_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left") + tokenizer.pad_token_id = 50256 + return tokenizer @pytest.fixture @@ -36,7 +26,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): id=0, inputs="Test", prefill_logprobs=True, - truncate=PAD_SEQUENCE_TO_MULTIPLE_OF, + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, ) @@ -48,14 +38,14 @@ def default_pb_batch(default_pb_request): @pytest.fixture -def default_causal_lm_batch(default_pb_batch, default_tokenizer): +def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer): return CausalLMBatch.from_pb( - default_pb_batch, default_tokenizer, torch.float32, torch.device("hpu") + default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu") ) @pytest.fixture -def default_multi_requests_causal_lm_batch(default_pb_request, default_tokenizer): +def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): req_0 = copy(default_pb_request) req_0.id = 1 req_1 = default_pb_request @@ -64,7 +54,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, default_tokenizer batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) return CausalLMBatch.from_pb( - batch_pb, default_tokenizer, torch.float32, torch.device("hpu") + batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu") ) @@ -72,37 +62,30 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): batch = default_causal_lm_batch assert batch.batch_id == default_pb_batch.id - assert len(batch.requests) == len(default_pb_batch.requests) + assert batch.requests == default_pb_batch.requests - for r in range(0,len(default_pb_batch.requests)): - assert batch.requests[r].data == default_pb_batch.requests[r] + assert len(batch.input_ids) == default_pb_batch.size + assert batch.input_ids[0][-1] == 14402 + assert torch.all(batch.input_ids[0][:-1] == 50256) - # For Gaudi we are adding padding of multiplication of bucket size - size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE - - assert len(batch.input_ids) == size_of_padded_to_bucket - - assert batch.input_ids[0][-2] == 4321 - assert batch.input_ids[0][-3] == 1 - assert torch.all(batch.input_ids[0][:-3] == PAD_TOKEN) - assert batch.input_ids[0][-1] == PAD_TOKEN - - assert batch.attention_mask[0][-1] == 0 - assert batch.attention_mask[0, -2] == 1 - assert batch.attention_mask[0, -3] == 1 - assert torch.all(batch.attention_mask[0, :-3] == 0) + assert batch.attention_mask[0, 0] == 1 + assert torch.all(batch.attention_mask[0, 1:] == 0) assert batch.past_key_values is None + assert all( [ - torch.equal(input_ids.to('cpu'), request.all_input_ids[:batch.input_length + 1, 0]) - for input_ids, request in zip(batch.input_ids, batch.requests) + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) ] ) - assert len(batch) == default_pb_batch.size + assert batch.input_lengths == [1] - assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate + assert len(batch) == default_pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + + assert batch.max_input_length == batch.input_lengths[0] def test_batch_concatenate_no_prefill(default_causal_lm_batch): @@ -115,66 +98,73 @@ def test_causal_lm_batch_type(default_causal_lm): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): + sequence_length = len(default_causal_lm_batch.all_input_ids[0]) + generations, next_batch, _ = default_causal_lm.generate_token( + default_causal_lm_batch + ) - sequence_length = len(default_causal_lm_batch.requests[0].all_input_ids) - generations, next_batch, _ = default_causal_lm.generate_token([default_causal_lm_batch]) - padding = next_batch.requests[0].stopping_criteria.max_new_tokens - + assert len(generations) == len(next_batch) assert isinstance(next_batch, CausalLMBatch) - assert len(next_batch.attention_mask[0]) == PAD_SEQUENCE_TO_MULTIPLE_OF - assert next_batch.requests[0].all_input_ids[-padding-2] == 4321 - assert torch.all(next_batch.requests[0].all_input_ids[-padding-1:] == PAD_TOKEN) - assert torch.all(next_batch.requests[0].all_input_ids[:-padding-3] == PAD_TOKEN) + assert len(next_batch.all_input_ids) == len(next_batch) + assert len(next_batch.all_input_ids[0]) == sequence_length + 1 + assert len(next_batch.attention_mask[0]) == 11 + assert next_batch.all_input_ids[0][-1] == 13 + assert next_batch.all_input_ids[0][-2] == 14402 + assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) - generations, next_batch, _ = default_causal_lm.generate_token([default_causal_lm_batch]) - assert torch.all(next_batch.attention_mask[0][PAD_SEQUENCE_TO_MULTIPLE_OF-3:PAD_SEQUENCE_TO_MULTIPLE_OF] == 1) - assert torch.all(next_batch.attention_mask[0][:PAD_SEQUENCE_TO_MULTIPLE_OF-3] == 0) - assert torch.all(next_batch.attention_mask[0][PAD_SEQUENCE_TO_MULTIPLE_OF+1:] == 0) + assert torch.all(next_batch.attention_mask[0][0:2] == 1) + assert torch.all(next_batch.attention_mask[0][2:] == 0) - assert next_batch.requests[0].all_input_ids[-padding-2] == 4321 - assert next_batch.requests[0].all_input_ids[-padding-1] == 292 - assert torch.all(next_batch.requests[0].all_input_ids[-padding:] == PAD_TOKEN) - assert torch.all(next_batch.requests[0].all_input_ids[:-padding-3] == PAD_TOKEN) + assert next_batch.input_ids.shape == (len(next_batch), 1) + assert next_batch.input_ids[0, 0] == 13 - assert next_batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF - assert next_batch.max_input_length == next_batch.input_length + assert next_batch.input_lengths == [2] + assert next_batch.max_input_length == next_batch.input_lengths[0] assert next_batch.past_key_values is not None assert all( - [p[0].shape == (BATCH_BUCKET_SIZE, 32, MAX_TOTAL_TOKENS, 128) for p in next_batch.past_key_values] + [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] ) assert all( - [p[1].shape == (BATCH_BUCKET_SIZE, 32, MAX_TOTAL_TOKENS, 128) for p in next_batch.past_key_values] + [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] ) assert all([generation.generated_text is None for generation in generations]) - assert all([len(generation.prefill_tokens) == PAD_SEQUENCE_TO_MULTIPLE_OF-1 for generation in generations]) - assert all([generation.tokens.token_ids[0] == 292 for generation in generations]) - assert all([generation.tokens.texts[0] == "ing" for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all( + [ + token_id.item() == 13 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == "." + for generation in generations + for token_text in generation.tokens.texts + ] + ) assert generations[0].request_id == 0 def test_causal_lm_generate_token_completion( default_causal_lm, default_causal_lm_batch ): - next_batch = default_causal_lm_batch - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - - for _ in range(default_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens - 1): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 - assert generations[0].generated_text.text == "ing the effect of a new method for the detection" - assert generations[0].request_id == default_causal_lm_batch.requests[0].data.id + assert generations[0].generated_text.text == ".java:784) at net.minecraft." + assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert ( generations[0].generated_text.generated_tokens - == default_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens + == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -182,49 +172,51 @@ def test_causal_lm_generate_token_completion_multi( default_causal_lm, default_multi_requests_causal_lm_batch ): next_batch = default_multi_requests_causal_lm_batch - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) for i in range( - default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens - 1 + default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 - assert generations[1].generated_text.text == "ing the effect of a" + assert generations[1].generated_text.text == ".java:784)" assert ( generations[1].request_id - == default_multi_requests_causal_lm_batch.requests[1].data.id + == default_multi_requests_causal_lm_batch.requests[1].id ) assert ( generations[1].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + ) + # Copy stopping_criterias before filtering + stopping_criterias = ( + default_multi_requests_causal_lm_batch.stopping_criterias.copy() ) - next_batch = next_batch.filter([next_batch.requests[0].data.id]) + next_batch = next_batch.filter([next_batch.requests[0].id]) for _ in range( - default_multi_requests_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens - default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens - 1 + stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 - assert generations[0].generated_text.text == "ing the effect of a new method for the detection" + assert generations[0].generated_text.text == ".java:784) at net.minecraft." assert ( generations[0].request_id - == default_multi_requests_causal_lm_batch.requests[0].data.id + == default_multi_requests_causal_lm_batch.requests[0].id ) assert ( generations[0].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -232,13 +224,11 @@ def test_batch_concatenate( default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch ): next_batch_0 = default_causal_lm_batch - _, next_batch_0, _ = default_causal_lm.generate_token([next_batch_0]) - _, next_batch_0, _ = default_causal_lm.generate_token([next_batch_0]) - _, next_batch_0, _ = default_causal_lm.generate_token([next_batch_0]) + _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0) next_batch_1 = default_multi_requests_causal_lm_batch - _, next_batch_1, _ = default_causal_lm.generate_token([next_batch_1]) - _, next_batch_1, _ = default_causal_lm.generate_token([next_batch_1]) + _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1) # Clone past_key_values before concatenating to compare after, # because they are removed from the concatenated batches @@ -251,135 +241,113 @@ def test_batch_concatenate( next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1]) - assert torch.equal(next_batch.requests[0].all_input_ids, next_batch_0.requests[0].all_input_ids) - assert torch.equal(next_batch.requests[1].all_input_ids, next_batch_1.requests[0].all_input_ids) - assert torch.equal(next_batch.requests[2].all_input_ids, next_batch_1.requests[1].all_input_ids) - + assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) + assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) + assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) assert torch.all( - next_batch.attention_mask[0:2, -next_batch.right_padding - 3: -next_batch.right_padding] == 1 + next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1 ) assert torch.all( - next_batch.attention_mask[2, -next_batch.right_padding - 4: -next_batch.right_padding] == 1 + next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1 ) - assert torch.all( - next_batch.attention_mask[3, -next_batch.right_padding - 3: -next_batch.right_padding] == 1 - ) - - assert torch.all( - next_batch.attention_mask[0:2, :-next_batch.right_padding-3] == 0) - assert torch.all( - next_batch.attention_mask[2, :-next_batch.right_padding-4] == 0) - assert torch.all( - next_batch.attention_mask[3, :-next_batch.right_padding-3] == 0) + assert torch.all(next_batch.attention_mask[1:, 3:] == 0) assert next_batch.batch_id == 0 - assert next_batch.input_ids[0,-next_batch.right_padding - 3] == 1 - assert next_batch.input_ids[0,-next_batch.right_padding - 2] == 4321 - assert next_batch.input_ids[0,-next_batch.right_padding - 1] == 292 + assert next_batch.input_ids[0, 0] == 12355 + assert torch.all(next_batch.input_ids[1:] == 13) - assert next_batch.max_input_length == 129 - - assert torch.all(next_batch.input_ids[0,-next_batch.right_padding:] == PAD_TOKEN) - assert torch.all(next_batch.input_ids[1,-next_batch.right_padding:] == PAD_TOKEN) - assert torch.all(next_batch.input_ids[2,-next_batch.right_padding:] == PAD_TOKEN) - assert torch.all(next_batch.input_ids[3,-next_batch.right_padding:] == PAD_TOKEN) - - assert next_batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF +1 - assert next_batch.max_input_length == PAD_SEQUENCE_TO_MULTIPLE_OF +1 + assert next_batch.input_lengths == [3, 2, 2] + assert next_batch.max_input_length == 3 assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[1:] == next_batch_1.requests - assert next_batch.requests[0].stopping_criteria == next_batch_0.requests[0].stopping_criteria - assert next_batch.requests[1].stopping_criteria == next_batch_1.requests[0].stopping_criteria - assert next_batch.requests[2].stopping_criteria == next_batch_1.requests[1].stopping_criteria - - assert next_batch.past_key_values is not None - - assert all([p[0].shape == (8, 32, 2048, 128) for p in next_batch.past_key_values]) - assert all([p[1].shape == (8, 32, 2048, 128) for p in next_batch.past_key_values]) + assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] + assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers + + assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0] + assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias assert next_batch.past_key_values is not None + assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values]) + assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values]) for i, past in enumerate(next_batch.past_key_values): - assert torch.equal(next_batch_0_past_key_values[i][0][0, 0,0:128], past[0][0][0][1:129]) + assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0]) assert torch.equal( - next_batch_1_past_key_values[i][0][:, :, 0:1][0], past[0][1:, :, 1 :2, :][0] + next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :] ) - assert torch.equal(next_batch_0_past_key_values[i][1][0, 0,0:128], past[1][0][0][1:129]) + assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0]) assert torch.equal( - next_batch_1_past_key_values[i][1][:, :, 0:1][0], past[1][1:, :, 1 :2, :][0] + next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :] ) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - for _ in range( - default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens - 2 + default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 3 - assert generations[2].generated_text.text == "ing the effect of a" - + assert generations[2].generated_text.text == ".java:784)" assert ( generations[2].request_id - == default_multi_requests_causal_lm_batch.requests[1].data.id + == default_multi_requests_causal_lm_batch.requests[1].id ) assert ( generations[2].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) next_batch = next_batch.filter( - [next_batch.requests[0].data.id, next_batch.requests[1].data.id] + [next_batch.requests[0].id, next_batch.requests[1].id] ) for _ in range( - default_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens - - default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens + default_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 - assert generations[0].generated_text.text == "ing the effect of a new method for the detection" - assert generations[0].request_id == default_causal_lm_batch.requests[0].data.id + assert generations[0].generated_text.text == ".java:784) at net.minecraft." + assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert ( generations[0].generated_text.generated_tokens - == default_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens + == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[1].data.id]) + next_batch = next_batch.filter([next_batch.requests[1].id]) for _ in range( - default_multi_requests_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens - - default_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens - - default_multi_requests_causal_lm_batch.requests[1].stopping_criteria.max_new_tokens + default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 4 ): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 - assert generations[0].generated_text.text == "ing the effect of a new method for the detection" + assert generations[0].generated_text.text == ".java:784) at net.minecraft." assert ( generations[0].request_id - == default_multi_requests_causal_lm_batch.requests[0].data.id + == default_multi_requests_causal_lm_batch.requests[0].id ) assert ( generations[0].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.requests[0].stopping_criteria.max_new_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_grammar.py b/server/tests/models/test_grammar.py deleted file mode 100644 index b5e65620..00000000 --- a/server/tests/models/test_grammar.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import json -import pytest -import torch - -from copy import copy - -from text_generation_server.pb import generate_pb2 -from text_generation_server.models import get_model -from text_generation_server.models.causal_lm import ( - CausalLMBatch, - PAD_SEQUENCE_TO_MULTIPLE_OF, -) - -PAD_TOKEN=0 - - -@pytest.fixture -def default_pb_grammar_parameters(): - grammar_schema = { - "properties": { - "activity": { - "type": "string" - }, - "animals": { - "items": { - "type":"string" - }, - "type": "array" - } - }, - "required": ["activity", "animals"] - } - return generate_pb2.NextTokenChooserParameters( - temperature=1.0, - repetition_penalty=1.3, - top_k=0, - top_p=1.0, - typical_p=1.0, - do_sample=False, - grammar_type=generate_pb2.GrammarType.GRAMMAR_TYPE_JSON, - grammar=json.dumps(grammar_schema).encode('utf-8'), - ) - - -@pytest.fixture(scope="session") -def default_grammar_response(): - return [ - 29912, 376, 29874, 312, 2068, 1115, 29871, 13, 29908, 29890, - 638, 292, 613, 259, 376, 273, 3039, 29879, 1115,518, 1678, - 376, 26169, 3284, 4117, 3284, 336, 617, 6150, 3108, 500, 2 - ] - - -@pytest.fixture(scope="session") -def default_causal_lm(): - return get_model("meta-llama/Llama-2-7b-hf", None, None, None, None) - - -@pytest.fixture(scope="session") -def default_tokenizer(default_causal_lm): - default_causal_lm.tokenizer.pad_token_id = PAD_TOKEN - return default_causal_lm.tokenizer - - -@pytest.fixture -def default_pb_request(default_pb_parameters): - return generate_pb2.Request( - id=0, - inputs="Test", - prefill_logprobs=True, - truncate=PAD_SEQUENCE_TO_MULTIPLE_OF, - parameters=default_pb_parameters, - stopping_parameters=generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10), - ) - - -@pytest.fixture -def default_pb_grammar_request(default_pb_grammar_parameters): - return generate_pb2.Request( - id=1, - inputs=f"Please use the following JSON schema to generate the output: I saw a puppy a cat and a raccoon during my bike ride in the park", - prefill_logprobs=True, - truncate=PAD_SEQUENCE_TO_MULTIPLE_OF, - parameters=default_pb_grammar_parameters, - stopping_parameters=generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=50), - ) - - -@pytest.fixture -def default_pb_batch(default_pb_request): - return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) - - -@pytest.fixture -def default_pb_grammar_batch(default_pb_grammar_request): - return generate_pb2.Batch(id=1, requests=[default_pb_grammar_request], size=1) - - -@pytest.fixture -def default_causal_lm_batch(default_pb_batch, default_tokenizer): - return CausalLMBatch.from_pb( - default_pb_batch, default_tokenizer, torch.float32, torch.device("hpu") - ) - - -@pytest.fixture -def default_causal_lm_grammar_batch(default_pb_grammar_batch, default_tokenizer): - return CausalLMBatch.from_pb( - default_pb_grammar_batch, default_tokenizer, torch.float32, torch.device("hpu") - ) - - -@pytest.fixture -def default_two_causal_lm_grammar_batches(default_pb_grammar_request, default_tokenizer): - req_0 = default_pb_grammar_request - req_0.id = 0 - req_1 = copy(default_pb_grammar_request) - req_1.id = 1 - - batch_0 = generate_pb2.Batch(id=0, requests=[req_0], size=1) - batch_1 = generate_pb2.Batch(id=1, requests=[req_1], size=1) - return [ - CausalLMBatch.from_pb( - b, default_tokenizer, torch.float32, torch.device("hpu") - ) for b in [batch_0, batch_1] - ] - - -def test_single_grammar_batch( - default_causal_lm, default_causal_lm_grammar_batch, default_grammar_response -): - counter = 0 - batch = default_causal_lm_grammar_batch - - # prefill request - generations, next_batch, _ = default_causal_lm.generate_token([batch]) - - # generate untill done - while next_batch is not None: - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 1 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter] - counter += 1 - print(generations[0].generated_text.text) - - -def test_multi_grammar_batches( - default_causal_lm, default_two_causal_lm_grammar_batches, default_grammar_response -): - counter_0, counter_1 = 0, 0 - batch_0, batch_1 = default_two_causal_lm_grammar_batches - - # prefill first request - generations, next_batch, _ = default_causal_lm.generate_token([batch_0]) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 1 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0] - counter_0 += 1 - - # prefill second request - generations, next_batch_1, _ = default_causal_lm.generate_token([batch_1]) - - # concatenate and generate - generations, next_batch, _ = default_causal_lm.generate_token([next_batch, next_batch_1]) - assert len(generations) == 2 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0] - assert generations[1].tokens.token_ids[0] == default_grammar_response[counter_1] - counter_0 += 1 - counter_1 += 1 - - # generate untill first request is done - while generations[0].generated_text is None: - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 2 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0] - assert generations[1].tokens.token_ids[0] == default_grammar_response[counter_1] - counter_0 += 1 - counter_1 += 1 - - # filter finished request - response = generations[0].generated_text.text - next_batch = next_batch.filter([next_batch.requests[1].data.id]) - - # generate last tokens for second request - while next_batch is not None: - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 1 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_1] - counter_1 += 1 - - assert response == generations[0].generated_text.text - - -def test_grammar_and_causal_batch( - default_causal_lm, default_causal_lm_grammar_batch, default_causal_lm_batch, default_grammar_response -): - counter = 0 - generations, next_batch, _ = default_causal_lm.generate_token([default_causal_lm_grammar_batch]) - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 1 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter] - counter += 1 - - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 1 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter] - counter += 1 - - # prefill second request - generations, next_batch_1, _ = default_causal_lm.generate_token([default_causal_lm_batch]) - - # concatenate and generate - generations, next_batch, _ = default_causal_lm.generate_token([next_batch, next_batch_1]) - assert len(generations) == 2 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter] - counter += 1 - - # generate untill second request is done - for _ in range( - next_batch.requests[1].stopping_criteria.max_new_tokens - 1 - ): - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 2 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter] - counter += 1 - - # filter finished request - assert len(generations) == 2 - assert ( - generations[1].request_id == next_batch.requests[1].data.id - ) - assert ( - generations[1].generated_text.generated_tokens == next_batch.requests[1].stopping_criteria.max_new_tokens - ) - assert generations[1].generated_text.text == "ing the effect of a new method for the detection" - next_batch = next_batch.filter([next_batch.requests[0].data.id]) - - # generate untill done - while next_batch is not None: - generations, next_batch, _ = default_causal_lm.generate_token([next_batch]) - assert len(generations) == 1 - assert generations[0].tokens.token_ids[0] == default_grammar_response[counter] - counter += 1 diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index ba9f5578..735ab5eb 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -61,7 +61,6 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni ) -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): batch = default_seq2seq_lm_batch sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) @@ -93,18 +92,15 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): assert batch.max_decoder_input_length == batch.decoder_input_lengths[0] -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch): with pytest.raises(ValueError): Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch]) -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_seq2seq_lm_batch_type(default_seq2seq_lm): assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) generations, next_batch, _ = default_seq2seq_lm.generate_token( @@ -172,7 +168,6 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) assert generations[0].request_id == 0 -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_seq2seq_lm_generate_token_completion( default_seq2seq_lm, default_seq2seq_lm_batch ): @@ -190,7 +185,6 @@ def test_seq2seq_lm_generate_token_completion( assert generations[0].generated_text.generated_tokens == 7 -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_seq2seq_lm_generate_token_completion_multi( default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch ): @@ -228,7 +222,6 @@ def test_seq2seq_lm_generate_token_completion_multi( assert generations[0].generated_text.generated_tokens == 7 -@pytest.mark.skip("seq2seq model not enabled on HPU yet") def test_batch_concatenate( default_seq2seq_lm, default_seq2seq_lm_batch, diff --git a/server/tests/models/test_starcoder.py b/server/tests/models/test_starcoder.py deleted file mode 100644 index 05bdd72b..00000000 --- a/server/tests/models/test_starcoder.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import pytest -import torch - -from copy import copy - -from text_generation_server.pb import generate_pb2 -from text_generation_server.models import get_model -from text_generation_server.models.starcoder import StarCoderCausalLMBatch -from text_generation_server.models.causal_lm import ( - PREFILL_BATCH_BUCKET_SIZE, - PAD_SEQUENCE_TO_MULTIPLE_OF, - MAX_TOTAL_TOKENS, - BATCH_BUCKET_SIZE, -) -PAD_TOKEN=0 - - -@pytest.fixture(scope="session") -def default_starcoder(): - return get_model("bigcode/starcoder", None, None, None, None) - - -@pytest.fixture(scope="session") -def default_tokenizer(default_starcoder): - default_starcoder.tokenizer.pad_token_id = PAD_TOKEN - return default_starcoder.tokenizer - - -@pytest.fixture -def default_pb_request(default_pb_parameters, default_pb_stop_parameters): - return generate_pb2.Request( - id=0, - inputs="Test", - prefill_logprobs=True, - truncate=PAD_SEQUENCE_TO_MULTIPLE_OF, - parameters=default_pb_parameters, - stopping_parameters=default_pb_stop_parameters, - ) - - -@pytest.fixture -def default_pb_batch(default_pb_request): - return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) - - -@pytest.fixture -def default_starcoder_batch(default_pb_batch, default_tokenizer): - return StarCoderCausalLMBatch.from_pb( - default_pb_batch, default_tokenizer, torch.float32, torch.device("hpu") - ) - - -@pytest.fixture -def default_multi_requests_starcoder_batch(default_pb_request, default_tokenizer): - req_0 = copy(default_pb_request) - req_0.id = 1 - req_1 = default_pb_request - req_1.id = 2 - req_1.stopping_parameters.max_new_tokens = 5 - - batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) - return StarCoderCausalLMBatch.from_pb( - batch_pb, default_tokenizer, torch.float32, torch.device("hpu") - ) - - -def test_starcoder_batch_type(default_starcoder): - assert default_starcoder.batch_type == StarCoderCausalLMBatch - - -def test_batch_from_pb(default_pb_batch, default_starcoder_batch): - batch = default_starcoder_batch - - assert batch.batch_id == default_pb_batch.id - assert len(batch.requests) == len(default_pb_batch.requests) - - for r in range(0,len(default_pb_batch.requests)): - assert batch.requests[r].data == default_pb_batch.requests[r] - - # For Gaudi we are adding padding of multiplication of bucket size - size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE - - assert len(batch.input_ids) == size_of_padded_to_bucket - assert batch.input_ids.shape == torch.Size([4, 128]) - - assert batch.input_ids[0][-2] == 1006 - assert batch.input_ids[1][-2] == 49 - assert batch.input_ids[2][-2] == 49 - assert batch.attention_mask[0][-2] == 1 - assert batch.attention_mask[1][-2] == 1 - assert batch.attention_mask[2][-2] == 1 - assert torch.all(batch.attention_mask[0, :-3] == 0) - - assert batch.past_key_values is None - assert all( - [ - torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0]) - for input_ids, request in zip(batch.input_ids, batch.requests) - ] - ) - - assert len(batch) == default_pb_batch.size - - assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate - - -def test_starcoder_generate_token(default_starcoder, default_starcoder_batch): - - sequence_length = len(default_starcoder_batch.requests[0].all_input_ids) - generations, next_batch, _ = default_starcoder.generate_token([default_starcoder_batch]) - padding = next_batch.requests[0].stopping_criteria.max_new_tokens - - assert isinstance(next_batch, StarCoderCausalLMBatch) - assert len(next_batch.attention_mask[0]) == PAD_SEQUENCE_TO_MULTIPLE_OF - assert next_batch.requests[0].all_input_ids[-padding-2] == 1006 - - assert torch.all(next_batch.requests[0].all_input_ids[-padding-1:] == PAD_TOKEN) - assert torch.all(next_batch.requests[0].all_input_ids[:-padding-3] == PAD_TOKEN) - - generations, next_batch, _ = default_starcoder.generate_token([default_starcoder_batch]) - assert torch.all(next_batch.attention_mask[0][PAD_SEQUENCE_TO_MULTIPLE_OF-2:PAD_SEQUENCE_TO_MULTIPLE_OF] == 1) - assert torch.all(next_batch.attention_mask[0][:PAD_SEQUENCE_TO_MULTIPLE_OF-3] == 0) - assert torch.all(next_batch.attention_mask[0][PAD_SEQUENCE_TO_MULTIPLE_OF+1:] == 0) - - assert next_batch.requests[0].all_input_ids[-padding-2] == 1006 - assert next_batch.requests[0].all_input_ids[-padding-1] == 26 - assert torch.all(next_batch.requests[0].all_input_ids[-padding:] == PAD_TOKEN) - assert torch.all(next_batch.requests[0].all_input_ids[:-padding-3] == PAD_TOKEN) - - assert next_batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF - assert next_batch.max_input_length == next_batch.input_length - - assert next_batch.past_key_values is not None - assert all( - [p[0].shape == (MAX_TOTAL_TOKENS, 256) for p in next_batch.past_key_values] - ) - assert all( - [p[1].shape == (MAX_TOTAL_TOKENS, 256) for p in next_batch.past_key_values] - ) - assert all([generation.generated_text is None for generation in generations]) - assert all([len(generation.prefill_tokens) == PAD_SEQUENCE_TO_MULTIPLE_OF-1 for generation in generations]) - assert all([generation.tokens.token_ids[0] == 26 for generation in generations]) - assert all([generation.tokens.texts[0] == "(" for generation in generations]) - assert generations[0].request_id == 0 - - -def test_starcoder_generate_token_completion( - default_starcoder, default_starcoder_batch -): - - next_batch = default_starcoder_batch - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - - for _ in range(default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - 1): - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert len(generations) == len(next_batch) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - - assert next_batch is None - - assert len(generations) == 1 - assert generations[0].generated_text.text == '(self):\n """\n Test that the test' - assert generations[0].request_id == default_starcoder_batch.requests[0].data.id - assert ( - generations[0].generated_text.generated_tokens - == default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - ) - - -def test_starcoder_generate_token_completion_multi( - default_starcoder, default_multi_requests_starcoder_batch -): - next_batch = default_multi_requests_starcoder_batch - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - - for i in range( - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - 1 - ): - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert len(generations) == len(next_batch) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert next_batch is not None - - assert len(generations) == 2 - assert generations[1].generated_text.text == '(self):\n """' - assert ( - generations[1].request_id - == default_multi_requests_starcoder_batch.requests[1].data.id - ) - assert ( - generations[1].generated_text.generated_tokens - == default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - ) - - next_batch = next_batch.filter([next_batch.requests[0].data.id]) - - for _ in range( - default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - 1 - ): - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert len(generations) == len(next_batch) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - - assert next_batch is None - - assert len(generations) == 1 - assert generations[0].generated_text.text == '(self):\n """\n Test that the test' - assert ( - generations[0].request_id - == default_multi_requests_starcoder_batch.requests[0].data.id - ) - assert ( - generations[0].generated_text.generated_tokens - == default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - ) - - -def test_batch_concatenate( - default_starcoder, default_starcoder_batch, default_multi_requests_starcoder_batch -): - next_batch_0 = default_starcoder_batch - _, next_batch_0, _ = default_starcoder.generate_token([next_batch_0]) - _, next_batch_0, _ = default_starcoder.generate_token([next_batch_0]) - _, next_batch_0, _ = default_starcoder.generate_token([next_batch_0]) - - next_batch_1 = default_multi_requests_starcoder_batch - _, next_batch_1, _ = default_starcoder.generate_token([next_batch_1]) - _, next_batch_1, _ = default_starcoder.generate_token([next_batch_1]) - - # Clone past_key_values before concatenating to compare after, - # because they are removed from the concatenated batches - next_batch_0_past_key_values = [x.clone() for x in next_batch_0.past_key_values] - next_batch_1_past_key_values = [x.clone() for x in next_batch_1.past_key_values] - - next_batch = StarCoderCausalLMBatch.concatenate([next_batch_0, next_batch_1]) - - assert torch.equal(next_batch.requests[0].all_input_ids, next_batch_0.requests[0].all_input_ids) - assert torch.equal(next_batch.requests[1].all_input_ids, next_batch_1.requests[0].all_input_ids) - assert torch.equal(next_batch.requests[2].all_input_ids, next_batch_1.requests[1].all_input_ids) - - - assert torch.all( - next_batch.attention_mask[0:2, -next_batch.right_padding - 2: -next_batch.right_padding] == 1 - ) - assert torch.all( - next_batch.attention_mask[2, -next_batch.right_padding - 3: -next_batch.right_padding] == 1 - ) - assert torch.all( - next_batch.attention_mask[3, -next_batch.right_padding - 2: -next_batch.right_padding] == 1 - ) - - assert torch.all( - next_batch.attention_mask[0:2, :-next_batch.right_padding-2] == 0) - assert torch.all( - next_batch.attention_mask[2, :-next_batch.right_padding-4] == 0) - assert torch.all( - next_batch.attention_mask[3, :-next_batch.right_padding-3] == 0) - - assert next_batch.batch_id == 0 - assert next_batch.input_ids[0,-next_batch.right_padding - 2] == 1006 - assert next_batch.input_ids[0,-next_batch.right_padding - 1] == 26 - - assert next_batch.max_input_length == 129 - - assert torch.all(next_batch.input_ids[0,-next_batch.right_padding:] == PAD_TOKEN) - assert torch.all(next_batch.input_ids[1,-next_batch.right_padding:] == PAD_TOKEN) - assert torch.all(next_batch.input_ids[2,-next_batch.right_padding:] == PAD_TOKEN) - assert torch.all(next_batch.input_ids[3,-next_batch.right_padding:] == PAD_TOKEN) - - assert next_batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF +1 - assert next_batch.max_input_length == PAD_SEQUENCE_TO_MULTIPLE_OF + 1 - - assert next_batch.requests[0] == next_batch_0.requests[0] - assert next_batch.requests[1:] == next_batch_1.requests - - assert next_batch.requests[0].stopping_criteria == next_batch_0.requests[0].stopping_criteria - assert next_batch.requests[1].stopping_criteria == next_batch_1.requests[0].stopping_criteria - assert next_batch.requests[2].stopping_criteria == next_batch_1.requests[1].stopping_criteria - - assert next_batch.past_key_values is not None - - assert all([p[0].shape == (2048, 256) for p in next_batch.past_key_values]) - assert all([p[1].shape == (2048, 256) for p in next_batch.past_key_values]) - - assert next_batch.past_key_values is not None - - for i, past in enumerate(next_batch.past_key_values): - assert torch.equal(next_batch_0_past_key_values[i][0,0,0:128], past[0][1:129][0, 0:128]) - assert torch.equal(next_batch_0_past_key_values[i][0,1,0:128], past[1][1:129][0, 0:128]) - assert torch.equal( - next_batch_1_past_key_values[i][:, :, 0:1][0][0][0], past[0][1:, :][0][0] - ) - - assert torch.equal( - next_batch_1_past_key_values[i][1:, :, 0:1][0][0][0], past[1][1:, :][0][0] - ) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - - for _ in range( - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - 2 - ): - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert len(generations) == len(next_batch) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert next_batch is not None - - assert len(generations) == 3 - assert generations[2].generated_text.text == '(self):\n """' - - assert ( - generations[2].request_id - == default_multi_requests_starcoder_batch.requests[1].data.id - ) - assert ( - generations[2].generated_text.generated_tokens - == default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - ) - - next_batch = next_batch.filter( - [next_batch.requests[0].data.id, next_batch.requests[1].data.id] - ) - - for _ in range( - default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - - 2 - ): - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert len(generations) == len(next_batch) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert next_batch is not None - - assert len(generations) == 2 - assert generations[0].generated_text.text == '(self):\n """\n Test that the test' - assert generations[0].request_id == default_starcoder_batch.requests[0].data.id - assert ( - generations[0].generated_text.generated_tokens - == default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - ) - - next_batch = next_batch.filter([next_batch.requests[1].data.id]) - - for _ in range( - default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - - default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - - 4 - ): - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert len(generations) == len(next_batch) - - generations, next_batch, _ = default_starcoder.generate_token([next_batch]) - assert next_batch is None - - assert len(generations) == 1 - assert generations[0].generated_text.text == '(self):\n """\n Test that the test' - assert ( - generations[0].request_id - == default_multi_requests_starcoder_batch.requests[0].data.id - ) - assert ( - generations[0].generated_text.generated_tokens - == default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - ) diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 94d2a8f2..5db32776 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -1,26 +1,12 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import pytest import torch -from transformers import AutoTokenizer - from text_generation_server.utils.tokens import ( StopSequenceCriteria, StoppingCriteria, FinishReason, batch_top_tokens, - make_tokenizer_optional, ) -@pytest.fixture -def skip_tokenizer_env_var(): - import os - os.environ["SKIP_TOKENIZER_IN_TGI"] = "true" - yield - del os.environ['SKIP_TOKENIZER_IN_TGI'] - - def test_stop_sequence_criteria(): criteria = StopSequenceCriteria("/test;") @@ -100,33 +86,3 @@ def test_batch_top_tokens(): assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]] assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]] assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]] - - -def test_pass_through_tokenizer(skip_tokenizer_env_var): - tokenizer = AutoTokenizer.from_pretrained( - 'meta-llama/Llama-2-7b-chat-hf', - revision=None, - padding_side="left", - truncation_side="left", - ) - tokenizer.pad_token_id = 2 - make_tokenizer_optional(tokenizer) - - input = ["1, 1724, 338, 6483, 6509, 29973", "?"] - tokenized_inputs = tokenizer( - input, - return_tensors="pt", - padding="max_length", - return_token_type_ids=False, - truncation=True, - max_length=1024, - ) - assert tokenized_inputs['input_ids'].size() == torch.Size([2, 1024]) - assert torch.equal(tokenized_inputs['input_ids'][0][1018:], torch.tensor([1, 1724, 338, 6483, 6509, 29973])) - assert torch.equal(tokenized_inputs['input_ids'][1][1023:], torch.tensor([tokenizer.pad_token_id])) - decoded_tokens = tokenizer.decode(tokenized_inputs["input_ids"][0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - assert decoded_tokens.split(',')[1018:] == ['1', '1724', '338', '6483', '6509', '29973'] - - -if __name__ == "__main__": - test_pass_through_tokenizer() diff --git a/server/tests/utils/test_watermark.py b/server/tests/utils/test_watermark.py index c7c83e8c..5d0aaf05 100644 --- a/server/tests/utils/test_watermark.py +++ b/server/tests/utils/test_watermark.py @@ -1,8 +1,7 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. +# test_watermark_logits_processor.py import os import numpy as np -import pytest import torch from text_generation_server.utils.watermark import WatermarkLogitsProcessor @@ -11,71 +10,35 @@ GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) DELTA = os.getenv("WATERMARK_DELTA", 2.0) -@pytest.fixture -def hpu_device(): - return torch.device("hpu") - - -@pytest.fixture -def input_ids_list(): - return [101, 2036, 3731, 102, 2003, 103] - - -@pytest.fixture -def input_ids_tensor(hpu_device): - return torch.tensor( - [[101, 2036, 3731, 102, 2003, 103]], - dtype=torch.int64, - device=hpu_device - ) - - -@pytest.fixture -def scores(hpu_device): - return torch.tensor( - [[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]], - device=hpu_device - ) - - -def test_seed_rng(input_ids_list, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - processor._seed_rng(input_ids_list) +def test_seed_rng(): + input_ids = [101, 2036, 3731, 102, 2003, 103] + processor = WatermarkLogitsProcessor() + processor._seed_rng(input_ids) assert isinstance(processor.rng, torch.Generator) -def test_seed_rng_tensor(input_ids_tensor, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - processor._seed_rng(input_ids_tensor) - assert isinstance(processor.rng, torch.Generator) - - -def test_get_greenlist_ids(input_ids_list, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - result = processor._get_greenlist_ids(input_ids_list, 10, hpu_device) +def test_get_greenlist_ids(): + input_ids = [101, 2036, 3731, 102, 2003, 103] + processor = WatermarkLogitsProcessor() + result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu")) assert max(result) <= 10 assert len(result) == int(10 * 0.5) -def test_get_greenlist_ids_tensor(input_ids_tensor, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - result = processor._get_greenlist_ids(input_ids_tensor, 10, hpu_device) - assert max(result) <= 10 - assert len(result) == int(10 * 0.5) - - -def test_calc_greenlist_mask(scores, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - greenlist_token_ids = torch.tensor([2, 3], device=hpu_device) +def test_calc_greenlist_mask(): + processor = WatermarkLogitsProcessor() + scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) + greenlist_token_ids = torch.tensor([2, 3]) result = processor._calc_greenlist_mask(scores, greenlist_token_ids) assert result.tolist() == [[False, False, False, False], [False, False, True, True]] assert result.shape == scores.shape -def test_bias_greenlist_logits(scores, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) +def test_bias_greenlist_logits(): + processor = WatermarkLogitsProcessor() + scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) green_tokens_mask = torch.tensor( - [[False, False, True, True], [False, False, False, True]], device=hpu_device + [[False, False, True, True], [False, False, False, True]] ) greenlist_bias = 2.0 result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias) @@ -83,13 +46,9 @@ def test_bias_greenlist_logits(scores, hpu_device): assert result.shape == scores.shape -def test_call(input_ids_list, scores, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - result = processor(input_ids_list, scores) - assert result.shape == scores.shape - - -def test_call_tensor(input_ids_tensor, scores, hpu_device): - processor = WatermarkLogitsProcessor(device=hpu_device) - result = processor(input_ids_tensor, scores) +def test_call(): + input_ids = [101, 2036, 3731, 102, 2003, 103] + processor = WatermarkLogitsProcessor() + scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) + result = processor(input_ids, scores) assert result.shape == scores.shape diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 74b87024..ad623ccc 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -1,8 +1,4 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import os -import psutil -import signal import sys import typer @@ -18,6 +14,8 @@ app = typer.Typer() class Quantization(str, Enum): bitsandbytes = "bitsandbytes" + bitsandbytes_nf4 = "bitsandbytes-nf4" + bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" awq = "awq" eetq = "eetq" @@ -44,6 +42,9 @@ def serve( otlp_endpoint: Optional[str] = None, ): if sharded: + assert ( + os.getenv("RANK", None) is not None + ), "RANK must be set when sharded is True" assert ( os.getenv("WORLD_SIZE", None) is not None ), "WORLD_SIZE must be set when sharded is True" @@ -76,73 +77,26 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value - dtype = "bfloat16" if dtype is None else dtype.value - - logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - - if sharded: - tgi_file = Path(__file__).resolve().parent / "tgi_service.py" - num_shard = int(os.getenv("WORLD_SIZE", "1")) - logger.info("CLI SHARDED = {}".format(num_shard)) - import subprocess - - cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" - cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" - cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" - if speculate is not None: - cmd += f"--speculate {speculate}" - logger.info("CLI server start deepspeed ={} ".format(cmd)) - sys.stdout.flush() - sys.stderr.flush() - with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: - do_terminate = False - current_handler = signal.getsignal(signal.SIGTERM) - def terminate_handler(sig, frame): - nonlocal do_terminate - do_terminate = True - if callable(current_handler): - current_handler(sig, frame) - - signal.signal(signal.SIGTERM, terminate_handler) - - finished = False - while not finished: - try: - if do_terminate: - parent = psutil.Process(proc.pid) - all_procs = parent.children(recursive=True) + [parent] - for p in all_procs: - try: - p.terminate() - except psutil.NoSuchProcess: - pass - _, alive = psutil.wait_procs(all_procs, timeout=30) - for p in alive: - p.kill() - - do_terminate = False - - proc.wait(timeout=3) - except subprocess.TimeoutExpired: - pass - else: - finished = True - - sys.stdout.flush() - sys.stderr.flush() - if proc.returncode != 0: - logger.error(f"{cmd} exited with status = {proc.returncode}") - return proc.returncode - else: - server.serve( - model_id, - revision, - sharded, - speculate, - dtype, - trust_remote_code, - uds_path + dtype = None if dtype is None else dtype.value + if dtype is not None and quantize not in { + None, + "bitsandbytes", + "bitsandbytes-nf4", + "bitsandbytes-fp4", + }: + raise RuntimeError( + "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) + server.serve( + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + ) @app.command() @@ -309,24 +263,30 @@ def download_weights( ) # Safetensors final filenames - local_st_files = [p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files] + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" + for p in local_pt_files + ] try: import transformers - from transformers import AutoConfig + import json - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - ) - architecture = config.architectures[0] + if is_local_model: + config_filename = os.path.join(model_id, "config.json") + else: + config_filename = hf_hub_download( + model_id, revision=revision, filename="config.json" + ) + with open(config_filename, "r") as f: + config = json.load(f) + architecture = config["architectures"][0] class_ = getattr(transformers, architecture) # Name for this varible depends on transformers version. discard_names = getattr(class_, "_tied_weights_keys", []) - discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) - except Exception: + except Exception as e: discard_names = [] # Convert pytorch weights to safetensors utils.convert_files(local_pt_files, local_st_files, discard_names) @@ -344,6 +304,8 @@ def quantize( percdamp: float = 0.01, act_order: bool = False, ): + if revision is None: + revision = "main" download_weights( model_id=model_id, revision=revision, @@ -357,6 +319,7 @@ def quantize( bits=4, groupsize=128, output_dir=output_dir, + revision=revision, trust_remote_code=trust_remote_code, upload_to_model_id=upload_to_model_id, percdamp=percdamp, diff --git a/server/text_generation_server/habana_quantization_env.py b/server/text_generation_server/habana_quantization_env.py deleted file mode 100644 index 3c06fd09..00000000 --- a/server/text_generation_server/habana_quantization_env.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import os - -quant_config = os.getenv("QUANT_CONFIG", "") -is_quantization_enabled = quant_config != "" - -if is_quantization_enabled: - os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true") - os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true") - os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false") - os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false") - os.environ.setdefault( - "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av") - os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") - - -def prepare_model_for_quantization(model): - if is_quantization_enabled: - if os.getenv("USE_INC", "1") != "0": - from neural_compressor.torch.quantization import FP8Config, convert - config = FP8Config.from_json_file(quant_config) - model = convert(model, config) - else: - import habana_quantization_toolkit - habana_quantization_toolkit.prep_model(model) - return model diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py index 05339282..57df1725 100644 --- a/server/text_generation_server/interceptor.py +++ b/server/text_generation_server/interceptor.py @@ -1,5 +1,3 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import torch import grpc @@ -8,8 +6,6 @@ from grpc_status import rpc_status from grpc_interceptor.server import AsyncServerInterceptor from loguru import logger from typing import Callable, Any -import traceback -import os class ExceptionInterceptor(AsyncServerInterceptor): @@ -24,7 +20,6 @@ class ExceptionInterceptor(AsyncServerInterceptor): response = method(request_or_iterator, context) return await response except Exception as err: - trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else '' method_name = method_name.split("/")[-1] logger.exception(f"Method {method_name} encountered an error.") @@ -35,10 +30,8 @@ class ExceptionInterceptor(AsyncServerInterceptor): if torch.cuda.is_available(): torch.cuda.empty_cache() - from .utils.debug import dbg_trace - dbg_trace('EXCEPTION', traceback.format_exc()) await context.abort_with_status( rpc_status.to_status( - status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace) + status_pb2.Status(code=code_pb2.INTERNAL, message=str(err)) ) ) diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 80836a95..321ced97 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -119,6 +119,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): none_tensor, temp_dq, ) + else: + RuntimeError("Cannot create handle") DEVICE = None diff --git a/server/text_generation_server/layers/gptq/exllamav2.py.rej b/server/text_generation_server/layers/gptq/exllamav2.py.rej deleted file mode 100644 index cde7b73a..00000000 --- a/server/text_generation_server/layers/gptq/exllamav2.py.rej +++ /dev/null @@ -1,10 +0,0 @@ -diff a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py (rejected hunks) -@@ -119,6 +119,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): - none_tensor, - temp_dq, - ) -+ else: -+ RuntimeError("Cannot create handle") - - - DEVICE = None diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/quant_linear.py index b52ceb0f..f60758b6 100644 --- a/server/text_generation_server/layers/gptq/quant_linear.py +++ b/server/text_generation_server/layers/gptq/quant_linear.py @@ -1,356 +1,356 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_fwd - -import triton -import triton.language as tl -from . import custom_autotune - - -# code based https://github.com/fpgaminer/GPTQ-triton -@custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, -) -@triton.jit -def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # eventually avoid overflow - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) - return output - - -class QuantLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - return output - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_fwd + +import triton +import triton.language as tl +from . import custom_autotune + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load( + scales_ptrs + g_idx[:, None] * stride_scales + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + g_idx[:, None] * stride_zeros + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) & maxq # eventually avoid overflow + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty( + (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 + ) + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +class QuantLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + return output + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + out = QuantLinearFunction.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 30bd8f6c..d4a325a9 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,4 +1,5 @@ import torch +import enum import os from loguru import logger @@ -8,36 +9,269 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path -# Needed to properly setup habana_frameworks -import text_generation_server.habana_quantization_env as hq_env - from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.bloom import BLOOM -from text_generation_server.models.starcoder import StarCoder -from text_generation_server.models.vlm_causal_lm import VlmCausalLM -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) +from text_generation_server.models.flash_causal_lm import FlashCausalLM +from text_generation_server.models.bloom import BLOOMSharded +from text_generation_server.models.mpt import MPTSharded +from text_generation_server.models.seq2seq_lm import Seq2SeqLM +from text_generation_server.models.rw import RW +from text_generation_server.models.opt import OPTSharded +from text_generation_server.models.galactica import GalacticaSharded +from text_generation_server.models.santacoder import SantaCoder +from text_generation_server.models.t5 import T5Sharded +from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.phi import Phi +# The flag below controls whether to allow TF32 on matmul. This flag defaults to False +# in PyTorch 1.12 and later. +torch.backends.cuda.matmul.allow_tf32 = True - -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi - +# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. +torch.backends.cudnn.allow_tf32 = True # Disable gradients torch.set_grad_enabled(False) +__all__ = [ + "Model", + "BLOOMSharded", + "CausalLM", + "GalacticaSharded", + "Seq2SeqLM", + "SantaCoder", + "OPTSharded", + "T5Sharded", + "get_model", +] + +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." + +FLASH_ATTENTION = True + +try: + from text_generation_server.models.flash_rw import FlashRWSharded + from text_generation_server.models.flash_gpt2 import FlashGPT2 + from text_generation_server.models.flash_neox import FlashNeoXSharded + from text_generation_server.models.flash_llama import ( + FlashLlama, + ) + from text_generation_server.models.flash_qwen2 import ( + FlashQwen2, + ) + from text_generation_server.models.flash_cohere import ( + FlashCohere, + ) + from text_generation_server.models.flash_gemma import ( + FlashGemma, + ) + from text_generation_server.models.pali_gemma import ( + PaliGemma, + ) + from text_generation_server.models.flash_santacoder import ( + FlashSantacoderSharded, + ) + from text_generation_server.models.idefics import IDEFICSSharded + from text_generation_server.models.llava_next import LlavaNext + from text_generation_server.models.idefics2 import Idefics2 + from text_generation_server.models.flash_mistral import FlashMistral + from text_generation_server.models.flash_mixtral import FlashMixtral + from text_generation_server.models.flash_phi import FlashPhi + from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 + from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.utils.flash_attn import ( + HAS_FLASH_ATTN_V2_CUDA, + HAS_FLASH_ATTN_V2_ROCM, + ) +except ImportError as e: + logger.warning(f"Could not import Flash Attention enabled models: {e}") + FLASH_ATTENTION = False + HAS_FLASH_ATTN_V2_CUDA = False + HAS_FLASH_ATTN_V2_ROCM = False + +if FLASH_ATTENTION: + __all__.append(FlashGPT2) + __all__.append(FlashNeoXSharded) + __all__.append(FlashRWSharded) + __all__.append(FlashSantacoderSharded) + __all__.append(FlashLlama) + __all__.append(IDEFICSSharded) + __all__.append(FlashMistral) + __all__.append(FlashMixtral) + __all__.append(FlashDbrx) + __all__.append(FlashPhi) + __all__.append(FlashQwen2) + __all__.append(FlashStarcoder2) + __all__.append(FlashGemma) + __all__.append(FlashCohere) + +MAMBA_AVAILABLE = True +try: + from text_generation_server.models.mamba import Mamba +except ImportError as e: + logger.warning(f"Could not import Mamba: {e}") + MAMBA_AVAILABLE = False + +if MAMBA_AVAILABLE: + __all__.append(Mamba) + + +class ModelType(enum.Enum): + IDEFICS2 = { + "type": "idefics2", + "name": "Idefics 2", + "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", + "multimodal": True, + } + LLAVA_NEXT = { + "type": "llava_next", + "name": "Llava Next (1.6)", + "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", + "multimodal": True, + } + LLAMA = { + "type": "llama", + "name": "Llama", + "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct", + } + PHI3 = { + "type": "phi3", + "name": "Phi 3", + "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", + } + GEMMA = { + "type": "gemma", + "name": "Gemma", + "url": "https://huggingface.co/google/gemma-7b", + } + COHERE = { + "type": "cohere", + "name": "Cohere", + "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus", + } + DBRX = { + "type": "dbrx", + "name": "Dbrx", + "url": "https://huggingface.co/databricks/dbrx-instruct", + } + MAMBA = { + "type": "ssm", + "name": "Mamba", + "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", + } + MISTRAL = { + "type": "mistral", + "name": "Mistral", + "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2", + } + MIXTRAL = { + "type": "mixtral", + "name": "Mixtral", + "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1", + } + GPT_BIGCODE = { + "type": "gpt_bigcode", + "name": "Gpt Bigcode", + "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", + } + PHI = { + "type": "phi", + "name": "Phi", + "url": "https://huggingface.co/microsoft/phi-1_5", + } + BAICHUAN = { + "type": "baichuan", + "name": "Baichuan", + "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat", + } + FALCON = { + "type": "falcon", + "name": "Falcon", + "url": "https://huggingface.co/tiiuae/falcon-7b-instruct", + } + STARCODER2 = { + "type": "starcoder2", + "name": "StarCoder 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + QWEN2 = { + "type": "qwen2", + "name": "Qwen 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + OPT = { + "type": "opt", + "name": "Opt", + "url": "https://huggingface.co/facebook/opt-6.7b", + } + T5 = { + "type": "t5", + "name": "T5", + "url": "https://huggingface.co/google/flan-t5-xxl", + } + GALACTICA = { + "type": "galactica", + "name": "Galactica", + "url": "https://huggingface.co/facebook/galactica-120b", + } + SANTACODER = { + "type": "santacoder", + "name": "SantaCoder", + "url": "https://huggingface.co/bigcode/santacoder", + } + BLOOM = { + "type": "bloom", + "name": "Bloom", + "url": "https://huggingface.co/bigscience/bloom-560m", + } + MPT = { + "type": "mpt", + "name": "Mpt", + "url": "https://huggingface.co/mosaicml/mpt-7b-instruct", + } + GPT2 = { + "type": "gpt2", + "name": "Gpt2", + "url": "https://huggingface.co/openai-community/gpt2", + } + GPT_NEOX = { + "type": "gpt_neox", + "name": "Gpt Neox", + "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", + } + IDEFICS = { + "type": "idefics", + "name": "Idefics", + "url": "https://huggingface.co/HuggingFaceM4/idefics-9b", + "multimodal": True, + } + + +__GLOBALS = locals() +for data in ModelType: + __GLOBALS[data.name] = data.value["type"] + def get_model( model_id: str, revision: Optional[str], + sharded: bool, + quantize: Optional[str], speculate: Optional[int], - dtype: Optional[torch.dtype], + dtype: Optional[str], trust_remote_code: bool, ) -> Model: - adapt_transformers_to_gaudi() + if dtype is None: + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None + elif dtype == "float16": + dtype = torch.float16 + elif dtype == "bfloat16": + dtype = torch.bfloat16 + else: + raise RuntimeError(f"Unknown dtype {dtype}") if speculate is not None: set_speculate(speculate) @@ -152,38 +386,535 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict["model_type"] + if model_type is None: + # TODO: fix how we determine model type for Mamba + if "ssm_cfg" in config_dict: + # *only happens in Mamba case + model_type = "ssm" + else: + raise RuntimeError( + f"Could not determine model type for {model_id} revision {revision}" + ) + quantization_config = config_dict.get("quantization_config", None) + if quantization_config is not None and quantize is None: + method = quantization_config.get("quant_method", None) + if method in {"gptq", "awq"}: + logger.info(f"Auto selecting quantization method {method}") + quantize = method + else: + logger.info(f"Unknown quantization method {method}") - if model_type == "gpt_bigcode": - return StarCoder(model_id, revision, dtype) - - if model_type == "bloom": - return BLOOM( + if model_type == MAMBA: + return Mamba( model_id, revision, + quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "llava_next": - return VlmCausalLM( - model_class=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, + if model_id.startswith("facebook/galactica"): + return GalacticaSharded( + model_id, + revision, + quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) + if ( + model_type == GPT_BIGCODE + or model_type == GPT2 + and model_id.startswith("bigcode/") + ): + if FLASH_ATTENTION: + return FlashSantacoderSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") + ) + else: + return SantaCoder( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == BLOOM: + return BLOOMSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == MPT: + return MPTSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == GPT2: + if FLASH_ATTENTION: + try: + return FlashGPT2( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + except RuntimeError as e: + # Lots of legacy models with various weight names. + logger.warning(f"Couldn't load flash gpt2 variant: {e}") + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == GPT_NEOX: + if FLASH_ATTENTION: + return FlashNeoXSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + return GPTNeoxSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + elif model_type == PHI: + if FLASH_ATTENTION: + return FlashPhi( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + elif model_type == "phi-msft": + if FLASH_ATTENTION: + raise NotImplementedError( + "Legacy phi-msft is not supported with Flash Attention" + ) + else: + return Phi( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: + if FLASH_ATTENTION: + return FlashLlama( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type == GEMMA: + if FLASH_ATTENTION: + return FlashGemma( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == COHERE: + if FLASH_ATTENTION: + return FlashCohere( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == DBRX: + if FLASH_ATTENTION: + return FlashDbrx( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]: + if sharded: + if FLASH_ATTENTION: + if config_dict.get("alibi", False): + raise NotImplementedError("sharded is not supported for this model") + return FlashRWSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) + else: + if FLASH_ATTENTION and not config_dict.get("alibi", False): + return FlashRWSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + return RW( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == MISTRAL: + sliding_window = config_dict.get("sliding_window", -1) + if ( + ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) + or HAS_FLASH_ATTN_V2_CUDA + or HAS_FLASH_ATTN_V2_ROCM + ): + return FlashMistral( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == MIXTRAL: + sliding_window = config_dict.get("sliding_window", -1) + if ( + ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) + or HAS_FLASH_ATTN_V2_CUDA + or HAS_FLASH_ATTN_V2_ROCM + ): + return FlashMixtral( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == STARCODER2: + sliding_window = config_dict.get("sliding_window", -1) + if ( + ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) + or HAS_FLASH_ATTN_V2_CUDA + or HAS_FLASH_ATTN_V2_ROCM + ): + return FlashStarcoder2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == QWEN2: + sliding_window = config_dict.get("sliding_window", -1) + if ( + ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) + or HAS_FLASH_ATTN_V2_CUDA + or HAS_FLASH_ATTN_V2_ROCM + ): + return FlashQwen2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == OPT: + return OPTSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == T5: + return T5Sharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type == IDEFICS: + if FLASH_ATTENTION: + return IDEFICSSharded( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == IDEFICS2: + if FLASH_ATTENTION: + return Idefics2( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == "paligemma": + if FLASH_ATTENTION: + return PaliGemma( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + + if model_type == LLAVA_NEXT: + if FLASH_ATTENTION: + return LlavaNext( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext")) + + if sharded: + raise NotImplementedError("sharded is not supported for AutoModel") + if quantize == "gptq": + raise NotImplementedError( + "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + if quantize == "awq": + raise NotImplementedError("awq quantization is not supported for AutoModel") + elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): + raise NotImplementedError("4bit quantization is not supported for AutoModel") + elif quantize == "eetq": + raise NotImplementedError("Eetq quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( model_id, revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: + return Seq2SeqLM( + model_id, + revision, + quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) + auto_map = config_dict.get("auto_map", None) + if trust_remote_code and auto_map is not None: + if "AutoModelForCausalLM" in auto_map.keys(): + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if "AutoModelForSeq2SeqLM" in auto_map.keys(): + return Seq2SeqLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + raise ValueError(f"Unsupported model type {model_type}") diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 6fe64374..65c9f317 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -1,14 +1,25 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import torch +import torch.distributed from typing import Optional, Type -from transformers import PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoConfig, + PreTrainedTokenizerBase, +) +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) class BloomCausalLMBatch(CausalLMBatch): @@ -20,33 +31,88 @@ class BloomCausalLMBatch(CausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super().from_pb( - pb=pb, - tokenizer=tokenizer, - dtype=dtype, - device=device, - ) + batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) batch.keys_head_dim_last = False return batch -class BLOOM(CausalLM): +class BLOOMSharded(CausalLM): def __init__( self, model_id: str, revision: Optional[str] = None, + quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - super(BLOOM, self).__init__( - model_id=model_id, + + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, - speculator=speculator, - dtype=dtype, + padding_side="left", + truncation_side="left", trust_remote_code=trust_remote_code, ) + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + slow_but_exact=False, + tp_parallel=True, + trust_remote_code=trust_remote_code, + ) + config.pad_token_id = 3 + config.quantize = quantize + config.speculator = speculator + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + prefix="transformer", + ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id, revision) + + model = BloomForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ): + outputs, speculative_logits = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + ) + + logits = outputs.logits + return logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index f814a8e2..81a02163 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,38 +1,13 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import bisect -from dataclasses import dataclass -from functools import wraps -import itertools -import math -import os -import tempfile -import time -from typing import Dict, List, Optional, Tuple, Type - import torch -import torch._dynamo -from loguru import logger +import time + +from dataclasses import dataclass from opentelemetry import trace +from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase +from typing import Optional, Tuple, List, Type, Dict -import text_generation_server.habana_quantization_env as hq_env -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from optimum.habana.checkpoint_utils import ( - get_repo_root, - model_on_meta, - write_checkpoints_json, -) -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - PreTrainedTokenizerBase, - AutoConfig, -) - -from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model +from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, Tokens, @@ -40,419 +15,55 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - StoppingCriteria, - make_tokenizer_optional, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -from text_generation_server.utils.debug import dbg_trace -from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling tracer = trace.get_tracer(__name__) -MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048)) -BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128)) -PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) - - -def torch_compile_for_eager(func): - if LAZY_MODE == 1: - return func - return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True}) - - -def round_up(number, k): - return (number + k - 1) // k * k - - -def to_tensor_indices(indices, device): - return torch.tensor(indices, dtype=torch.long, device=device) - - -def calculate_chunks(offset): - result = [] - while offset != 0: - sign = 1 if offset > 0 else -1 - best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1] - result.append(best_chunk) - offset = offset - best_chunk - return result - - -def biggest_single_chunk(offset): - if offset != 0: - idx = bisect.bisect(CHUNK_SIZES, abs(offset)) - return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) - else: - return 0 - - -@torch_compile_for_eager -def grouped_pad(tensor_groups, dims, values): - grouped_result = [] - for tensors, dim, value in zip(tensor_groups, dims, values): - padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0 - if padding > 0: - assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}' - pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) - result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors] - else: - result = [t for t in tensors] - grouped_result.append(result) - htorch.core.mark_step() - return grouped_result - - -@torch_compile_for_eager -def roll(tensor, chunk, dim, merge_graphs): - if dim is None: - return tensor - tensor = torch.roll(tensor, chunk, dim) - if not merge_graphs: - htorch.core.mark_step() - return tensor - - -def grouped_roll(tensor_groups, chunk, dims, merge_graphs): - tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)] - if merge_graphs: - htorch.core.mark_step() - return tensor_groups - - -@torch_compile_for_eager -def grouped_shift(tensor_groups, dims, offset, merge_graphs): - chunks = calculate_chunks(offset) - for c in chunks: - tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs) - return tensor_groups - - -def move(dst_tensors, dst_indices, src_tensors): - bs_dim = 0 - num_indices = dst_indices.size(0) - for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)): - if src_t.size(bs_dim) != num_indices: - src_t = torch.narrow(src_t, bs_dim, 0, num_indices) - dst_t.index_copy_(bs_dim, dst_indices, src_t) - htorch.core.mark_step() - - -def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups): - for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups): - move(dst_tensors, dst_indices, src_tensors) - - -@torch_compile_for_eager -def extend_tensor(tensor, padding, dim): - result = torch.cat([tensor, padding], dim=dim) - htorch.core.mark_step() - return result - - -@torch_compile_for_eager -def extend_batch(tensors, target_bs, dim): - diff = target_bs - tensors[0].size(dim) - # TODO: add support for shrinking bs - if diff <= 0: - return tensors - shape = list(tensors[0].shape) - shape[dim] = diff - padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) - tensors = [extend_tensor(t, padding, dim) for t in tensors] - return tensors - - -def grouped_extend_batch(tensor_groups, target_bs, bs_dims): - tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)] - return tensor_groups - - -@torch_compile_for_eager -def merge(tensor_group): - tensor_group = [torch.stack(tensor_group)] - htorch.core.mark_step() - return tensor_group - - -@torch_compile_for_eager -def split(tensor_group, clone_data): - tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)] - if clone_data: - tensor_group = [t.clone() for t in tensor_group] - htorch.core.mark_step() - return tensor_group - - -def remove_kv_cache_from_output(module): - orig_fwd = module.forward - - @wraps(orig_fwd) - def forward(*args, **kwargs): - if kwargs["past_key_values"] is not None: - kwargs["return_dict"] = False - output = orig_fwd(*args, **kwargs) - first_value, second_value, *_ = output - if first_value.nelement() < 2: - return second_value - else: - return first_value - else: - kwargs["return_dict"] = True - return orig_fwd(*args, **kwargs) - - module.forward = forward - return module - - -@dataclass -class CausalLMRequest: - idx: int - data: generate_pb2.Request - input_length: int - prefix_offset: int - read_offset: int - stopping_criteria: StoppingCriteria - - all_input_ids: torch.Tensor - - @classmethod - def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase): - return cls( - idx=idx, - data=data, - input_length=None, - prefix_offset=None, - read_offset=None, - stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer), - all_input_ids=None,) - - def update_idx(self, new_idx): - prev = self.idx - self.idx = new_idx - return (new_idx, prev) - @dataclass class CausalLMBatch(Batch): batch_id: int - requests: List[CausalLMRequest] + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] # Decoder values input_ids: torch.Tensor attention_mask: torch.Tensor position_ids: torch.Tensor past_key_values: Optional[List[Tuple]] - merged_kv_cache: bool + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] # Generation helpers - next_token_chooser: HeterogeneousNextTokenChooser + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] top_n_tokens: List[int] top_n_tokens_tensor: torch.Tensor - input_length: int + # Metadata used for padding + max_input_length: int + padding_right_offset: int - logits = None - past = None + # Maximum number of tokens this batch will grow to + max_tokens: int + # Past metadata keys_head_dim_last: bool = True def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, - request_ids=[r.data.id for r in self.requests], + request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, ) - def detach_kv_cache(self): - past_keys = [past[0] for past in self.past_key_values] - past_values = [past[1] for past in self.past_key_values] - del self.past_key_values - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - # TODO: Add support for models that don't store kv_cache in a list - self.past_key_values = list(zip(past_keys, past_values)) - - def merge_kv_cache_if_needed(self, target_bs, offset): - pad_needed = self.seq_length < MAX_TOTAL_TOKENS - shift_needed = offset != 0 - expand_needed = target_bs > self.batch_size - # Very simple heuristic to determine whether we should merge tensors - # this needs tuning for other models/scenarios - small_bs = len(self.past_key_values) > self.batch_size - if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed): - past_keys, past_values = self.detach_kv_cache() - past_keys = merge(past_keys) - past_values = merge(past_values) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = True - - def split_kv_cache_if_needed(self, clone_data): - if self.merged_kv_cache: - past_keys, past_values = self.detach_kv_cache() - past_keys = split(past_keys, clone_data) - past_values = split(past_values, clone_data) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = False - - def get_tensor_groups(self): - past_keys, past_values = self.detach_kv_cache() - seq_dim = -1 - key_dim = -2 if self.keys_head_dim_last else -1 - value_dim = -2 - tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values] - # We don't need to align position_ids - seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] - bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) - return tensors, seq_dims, bs_dims - - def set_tensor_groups(self, tensors): - self.input_ids = tensors.pop(0)[0] - self.attention_mask = tensors.pop(0)[0] - self.position_ids = tensors.pop(0)[0] - past_keys = tensors.pop(0) - past_values = tensors.pop(0) - self.attach_kv_cache(past_keys, past_values) - - def realign(self, target_bs, offset, pad_token_id): - tensors, seq_dims, _ = self.get_tensor_groups() - tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0]) - tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache) - self.set_tensor_groups(tensors) - - def expand_bs(self, target_bs): - tensors, _, bs_dims = self.get_tensor_groups() - tensors = grouped_extend_batch(tensors, target_bs, bs_dims) - self.set_tensor_groups(tensors) - - def used_indices(self): - return [req.idx for req in self.requests] - - def update_indices(self, new_indices): - for req, new_idx in zip(self.requests, new_indices): - req.idx = new_idx - return self.used_indices() - - def free_indices_generator(self): - used = set(req.idx for req in self.requests) - return (i for i in range(self.batch_size) if i not in used) - - def move_data(self, src_batches): - dst_tensors, _, dst_dims = self.get_tensor_groups() - free_indices_gen = self.free_indices_generator() - for src_b in src_batches: - dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device) - src_tensors, _, src_dims = src_b.get_tensor_groups() - grouped_move(dst_tensors, dst_indices, src_tensors) - self.set_tensor_groups(dst_tensors) - - @classmethod - def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! Cannot recombine before prefill!") - - total_requests = sum(len(b) for b in batches) - new_bs = round_up(total_requests, BATCH_BUCKET_SIZE) - batch_id = batches[0].batch_id - device = batches[0].input_ids.device - - input_lengths = [b.input_length for b in batches] - max_input_length = max(input_lengths) - offsets = [max_input_length - b.input_length for b in batches] - - cur_padding = [b.right_padding for b in batches] - # For prefill there is a space allocated only for first token - # Need to add padding to the max total tokens before first decode - - moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] - dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = (batches[dst_batch_idx].batch_size < new_bs) - - # TODO: Add support for changing max seq len, i.e. due to output length bucketing - # FIXME: max_seq_len for non optimized code - if len(batches) > 1: - scenario = 'CONCAT' - elif reshape: - scenario = 'RESHAPE' - elif cur_padding[dst_batch_idx] <= 0: - scenario = 'SHIFT' - offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] - max_input_length = max_input_length + offsets[dst_batch_idx] - else: - # Nothing to do - return batches[0] - - dbg_trace( - scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' - f' reqs:{[len(b) for b in batches]}' - f' offsets:{offsets}' - f' input_lengths:{input_lengths}' - f' cur_padding:{cur_padding}' - f' dst_batch:{dst_batch_idx}') - - grouped_requests = [[req for req in batch.requests] for batch in batches] - flat_requests = list(itertools.chain(*grouped_requests)) - - for i in range(len(batches)): - target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size - batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) - batches[i].realign(target_bs, offsets[i], pad_token_id) - batches[i].split_kv_cache_if_needed(i == dst_batch_idx) - batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) - - top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) - - parameters = [r.data.parameters for r in flat_requests] - # append the dummy parameters for dummy requests - batch_size = batches[dst_batch_idx].batch_size - parameters = pad_next_token_chooser_parameters(parameters, batch_size) - - # update past grammar states - fsm_grammar_states = [0] * batch_size - for batch in batches: - for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[dst_batch_idx].next_token_chooser.dtype, - batches[dst_batch_idx].next_token_chooser.device, - batches[dst_batch_idx].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - input_ids = batches[dst_batch_idx].input_ids - attention_mask = batches[dst_batch_idx].attention_mask - position_ids = batches[dst_batch_idx].position_ids - past_key_values = batches[dst_batch_idx].past_key_values - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=flat_requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - @classmethod def from_pb( cls, @@ -461,152 +72,433 @@ class CausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') - requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} - max_input_length = max(r.data.truncate for r in requests) - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(r.inputs) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) - # TODO: Add support for sparse batches - top_n_tokens = [r.top_n_tokens for r in pb.requests] - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE) - missing_inputs = new_bs - len(requests) - dummy_inputs = ["?"] * missing_inputs - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) tokenized_inputs = tokenizer( - [r.data.inputs for r in requests] + dummy_inputs, + inputs, return_tensors="pt", - padding="longest", + padding=True, return_token_type_ids=False, truncation=True, - max_length=max_input_length, - ) + max_length=max_truncation, + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(input_len - 5) + read_offsets.append(input_len) - input_len = tokenized_inputs["input_ids"].shape[1] - - bucket_size = max_input_length - left_padding = max_input_length - input_len - if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: - assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" - rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id + # Allocate maximum attention_mask + attention_mask = input_ids.new_zeros( + (pb.size, max_input_length + padding_right_offset) ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 + position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) + all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - # New input length after left padding - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - - htorch.core.mark_step() + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, - requests=requests, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, ) @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: - dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}') - request_ids = set(request_ids) - self.requests = [req for req in self.requests if req.data.id in request_ids] + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + position_ids = self.position_ids[keep_indices] + self.attention_mask = self.attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + ] + + # Ensure that past_key_values tensors can be updated in-place + if type(self.past_key_values[0]) == tuple: + self.past_key_values = [list(layer) for layer in self.past_key_values] + + # Update tensors in-place to allow incremental garbage collection + past_kv_length = max_input_length - 1 + for layer in self.past_key_values: + past_keys, past_values = layer + if len(past_keys.shape) == 3: + # Force past to be of dim [self_size, num_heads, ...] for easy indexing + past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) + past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) + if self.keys_head_dim_last: + layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] + else: + layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] + del past_keys + layer[1] = past_values[keep_indices, :, -past_kv_length:, :] + del past_values + + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.position_ids = position_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + self.top_n_tokens_tensor = top_n_tokens_tensor + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + return self @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id) + def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": + # Used for padding + total_batch_size = 0 + max_input_length = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + max_tokens = 0 + + # Batch tensors + input_ids = None + attention_mask = None + position_ids = None + past_key_values = [] + top_n_tokens_tensor = None + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + # Create padded tensor + if attention_mask is None: + attention_mask = batch.attention_mask.new_zeros( + (total_batch_size, max_input_length + padding_right_offset), + ) + + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + + # We need to slice the attention mask to remove padding from previous steps + # and to remove unused allocated space + left_offset = max_input_length - batch.max_input_length + batch_left_offset = ( + batch.attention_mask.shape[1] + - batch.max_input_length + - batch.padding_right_offset + ) + attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = batch.attention_mask[ + :, + batch_left_offset : -batch.padding_right_offset, + ] + + # Create empty tensor + # position_ids is always of shape [batch_size, 1] + if position_ids is None: + position_ids = batch.position_ids.new_empty((total_batch_size, 1)) + position_ids[start_index:end_index] = batch.position_ids + + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + # And ensure that we can update tensors in-place + if type(batch.past_key_values[0]) == tuple: + batch.past_key_values = [ + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + for layer in batch.past_key_values + ] + elif len(batch.past_key_values[0][0].shape) == 3: + for layer in batch.past_key_values: + for k, t in enumerate(layer): + layer[k] = t.view(len(batch), -1, *t.shape[-2:]) + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + start_index = end_index + + first_past_kvs = batches[0].past_key_values + _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape + + padded_past_values_shape = ( + total_batch_size, + num_heads, + max_input_length - 1, + head_dim, + ) + + if batches[0].keys_head_dim_last: + padded_past_keys_shape = padded_past_values_shape + else: + # seq_length is last for BLOOM + padded_past_keys_shape = ( + total_batch_size, + num_heads, + head_dim, + max_input_length - 1, + ) + + # Iterate over attention layers + # Concatenate past key values layer by layer to allow incremental garbage collection + for j in range(len(first_past_kvs)): + padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + if batch.keys_head_dim_last: + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( + past_keys[:, :, -past_seq_len:, :] + ) + else: + # BLOOM case + padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( + past_keys[:, :, :, -past_seq_len:] + ) + del past_keys + + start_index = end_index + + padded_past_values = first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] + ) + del past_values + + # Update values + start_index = end_index + + past_key_values.append([padded_past_keys, padded_past_values]) + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + ) def __len__(self): return len(self.requests) - @property - def max_input_length(self): - return max(req.input_length for req in self.requests) - - @property - def batch_size(self): - return self.attention_mask.size(0) - - @property - def seq_length(self): - return self.attention_mask.size(1) - - @property - def right_padding(self): - return self.seq_length - self.input_length - - # Maximum number of tokens this batch will grow to - @property - def max_tokens(self): - max_total_tokens = self.attention_mask.size(1) - return len(self.requests) * max_total_tokens - class CausalLM(Model): def __init__( self, model_id: str, revision: Optional[str] = None, + quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") - self.prev_bs = 0 + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype - # Create tokenizer tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, @@ -614,444 +506,153 @@ class CausalLM(Model): truncation_side="left", trust_remote_code=trust_remote_code, ) - make_tokenizer_optional(tokenizer) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - if world_size > 1: - model = self.get_deepspeed_model( - model_id, dtype, revision - ) - model = self.prepare_model_for_quantization(model) - else: - get_repo_root(model_id) - - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained( - model_id - ) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs - ) - model = self.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace( - "TORCH COMPILE", f'Torch compiling of model') - model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) - - model = self.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map=( + "auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None + ), + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if ( + torch.cuda.is_available() + and torch.cuda.device_count() == 1 + and quantize != "bitsandbytes" + ): + model = model.cuda() if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: tokenizer.pad_token_id = model.config.pad_token_id elif model.config.eos_token_id is not None: - if isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" - ) + tokenizer.pad_token_id = model.config.eos_token_id elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2"]: - - if model.config.model_type in ["llama", "mistral", "qwen2"]: - kwargs["attn_softmax_bf16"] = True - kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": - kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": - kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, - rank=rank, - kwargs=kwargs, ) - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - def get_deepspeed_model( - self, - model_id: str, - dtype: torch.dtype, - revision: Optional[str] = None - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = { - "revision": revision - } - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) - else: - get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - write_checkpoints_json(model_id, local_rank, checkpoints_json) - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return { - 'type': rope_scaling, 'factor': float(rope_factor) - } - - def setup_quantization(self, model): - if hq_env.is_quantization_enabled: - htorch.core.quantization._mark_params_as_const(model) - htorch.core.quantization._check_params_as_const(model) - htorch.core.hpu_initialize(model) - return model - - def prepare_model_for_quantization(self, model): - if hq_env.is_quantization_enabled: - if model.config.model_type == "llama": - self.patch_scoped_linear_all_reduce(model) - model = hq_env.prepare_model_for_quantization(model) - return model - - def patch_scoped_linear_all_reduce(self, model): - from deepspeed.module_inject.layers import LinearAllreduce - from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce - for name, module in model.named_children(): - if type(module) is LinearAllreduce: - SL = ScopedLinearAllReduce(mod=module) - setattr(model, name, SL) - self.patch_scoped_linear_all_reduce(module) - @property def batch_type(self) -> Type[CausalLMBatch]: return CausalLMBatch def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) def forward( - self, - input_ids, - attention_mask, - position_ids, - token_idx, - past_key_values: Optional[List[Tuple]] = None, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]] + ]: # Model Forward kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, - "token_idx": token_idx, + "use_cache": True, + "return_dict": True, } - - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama" : - kwargs["lazy_mode"] = LAZY_MODE == 1 - if self.has_position_ids: kwargs["position_ids"] = position_ids - if bypass_hpu_graph != None: - kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - - kwargs.update(self.kwargs) - if past_key_values is not None: - return self.model.forward(**kwargs) + outputs = self.model.forward(**kwargs) + if isinstance(outputs, tuple): + outputs, speculative_logits = outputs else: - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values + speculative_logits = None + return outputs.logits, speculative_logits, outputs.past_key_values @tracer.start_as_current_span("generate_token") def generate_token( - self, batches: List[CausalLMBatch] + self, batch: CausalLMBatch ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: start = time.time_ns() + # slice the attention mask to the correct shape + attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + + logits, speculative_logits, past = self.forward( + batch.input_ids, + attention_mask, + batch.position_ids, + batch.past_key_values, + ) + # Results generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding - token_idx = torch.tensor(token_idx_scalar).to(self.device) + stopped = True - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( - batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append({ - 'next_token_ids': next_token_ids, - 'next_token_logprobs': next_token_logprobs, - }) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append({ - 'req': req, - 'prev_req_idx': req.idx, - 'batch_id': batch_id, - 'seed': batch.next_token_chooser.seeds[req_idx], - 'do_sample': batch.next_token_chooser.do_sample[req_idx], - 'top_n_tokens': batch.top_n_tokens[req_idx], - 'top_token_ids': batch_top_token_ids[req_idx], - 'top_token_logprobs': batch_top_token_logprobs[req_idx], - 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], - }) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # Adjust lengths - batch.input_length += 1 - - # Update position_ids - if prefill: - batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - else: - batch.position_ids += 1 - # Update past key values - if prefill: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) - - scenario = 'PREFILL' if prefill else 'GENERATE' - if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs: - self.model.clear_cache() - self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE) - dbg_trace( - scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') - assert batch.right_padding > 0, 'No more room for next token!' - - # Execute batch - if prefill: - # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, - ) - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) - input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) - batch.logits = self.forward( - input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, - ) - - htorch.core.mark_step() + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.log_softmax(logits[:, -1], -1), + accepted_ids, + ) start_decode = time.time_ns() - # Stage 3. Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() - prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() - htorch.core.mark_step() + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, + ) - for req_data in requests_to_generate: - req = req_data['req'] - i = req_data['prev_req_idx'] - prev_batch_id = req_data['batch_id'] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] - next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data['do_sample'] - seed = req_data['seed'] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data['top_n_tokens'] - top_token_ids = req_data['top_token_ids'] - top_token_logprobs = req_data['top_token_logprobs'] - grammar_state = req_data['grammar_state'] + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) # Append next token to all tokens - all_input_ids[input_length] = next_token_id + all_input_ids = torch.cat([all_input_ids, next_token_id]) new_input_length = input_length + 1 # Generated token - if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: - next_token_text = '' - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) # Evaluate stopping criteria stop, reason = stopping_criteria( - next_token_id, + next_token_id_squeezed, next_token_text, ) @@ -1063,17 +664,23 @@ class CausalLM(Model): if i % self.world_size == self.rank: if stop: # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed else: - output_text = self.decode( - all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] - ) + seed = None + generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + output_text, stopping_criteria.current_tokens, reason, seed ) else: generated_text = None @@ -1081,8 +688,12 @@ class CausalLM(Model): # Prefill if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0: new_input_length - 1] + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, clean_up_tokenization_spaces=False, @@ -1126,10 +737,10 @@ class CausalLM(Model): request.id, prefill_tokens, Tokens( - [next_token_id], + [next_token_id_squeezed], [next_token_logprob], [next_token_text], - [next_token_id in self.all_special_ids], + [next_token_id_squeezed.item() in self.all_special_ids], ), generated_text, top_tokens, @@ -1137,58 +748,37 @@ class CausalLM(Model): generations.append(generation) - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) + # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() ) + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: - self.hb_profiler.stop() - else: - self.hb_profiler.step() + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def warmup(self, batches: List[CausalLMBatch]) -> None: - def get_unfinished_requests(requests: List[CausalLMRequest]) -> List[int]: - return [ - request.data.id for request in requests - if request.stopping_criteria.current_tokens < request.stopping_criteria.max_new_tokens - ] - - # prefill - _, prefill_batch, _ = self.generate_token([batches.pop(0)]) - # decode - _, decode_batch, _ = self.generate_token([prefill_batch]) - # shifts - self.shifting_warmup(decode_batch) - - while len(batches) > 0: - # prefill - _, prefill_batch, _ = self.generate_token([batches.pop(0)]) - # concatenate and decode - _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch]) - # filter finished requests - request_ids = get_unfinished_requests(decode_batch.requests) - if len(request_ids) < len(decode_batch.requests): - decode_batch = decode_batch.filter(request_ids) - - def shifting_warmup(self, batch: CausalLMBatch) -> None: - chunk_sizes = CHUNK_SIZES.copy() - chunk_sizes.extend([-chunk for chunk in chunk_sizes]) - - for chunk in chunk_sizes: - batch.merge_kv_cache_if_needed(batch.batch_size, chunk) - batch.realign(batch.batch_size, chunk, 0) - batch.split_kv_cache_if_needed(True) + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 1ef55019..de9673aa 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -21,12 +21,18 @@ import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN -from transformers.models.llava_next.modeling_llava_next import ( - unpad_image, -) -from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration from transformers.image_processing_utils import select_best_resolution +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. @@ -50,13 +56,100 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): - +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext +class LlavaNextMultiModalProjector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.linear_1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaNextForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + vision_config = config.vision_config + # Instead of selecting in hidden_states[-2]. + # Instead compute only the n -2 + 1 layers and don't pool + if config.vision_feature_layer < 0: + vision_config.num_hidden_layers += config.vision_feature_layer + 1 + else: + vision_config.num_hidden_layers = config.vision_feature_layer + 1 + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + + self.multi_modal_projector = LlavaNextMultiModalProjector( + prefix="multi_modal_projector", config=config, weights=weights + ) + + self.image_newline = weights.get_tensor("image_newline") + + self.vocab_size = config.text_config.vocab_size + self.config = config + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + self.language_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + def _merge_input_ids_with_image_features( self, + input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_features: torch.Tensor, - input_ids: torch.Tensor, ): """In place merges in vision_embeddings with inputs_embeds.""" mask = input_ids == self.config.image_token_index @@ -71,226 +164,120 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, + # Unused for this model + pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = False, - flash_attention_recompute: Optional[bool] = False, ): + inputs_embeds = self.language_model.embed_tokens(input_ids) + if pixel_values is not None and len(pixel_values) > 0: + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings - if token_idx is not None: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) + image_features = self.vision_tower(pixel_values) - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state - logits = outputs[0] - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, - ) + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature else: - use_flash_attention = kwargs.get("use_flash_attention", False) - flash_attention_recompute = kwargs.get("flash_attention_recompute", False) - - position_ids = kwargs.get("position_ids", None) - labels = kwargs.get("labels", None) - if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: - vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. Merge text and images - batch_size, num_patches, num_channels, height, width = pixel_values.shape - reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) - image_features = self.vision_tower( - reshaped_pixel_values, - output_hidden_states=True, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - selected_image_feature = image_features.hidden_states[vision_feature_layer] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [image.shape[0] for image in pixel_values] - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError("The number of patches is not consistent with the image size.") - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx].tolist(), - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - - image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids) - self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." ) - return model_inputs + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." + ) + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_features + ) + + hidden_states = self.language_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.language_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 87e192e0..e6125e29 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig from typing import Optional, Tuple, Type from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import FlashCausalLM +from text_generation_server.models import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE from text_generation_server.models.cache_manager import ( get_cache_manager, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 7546e70b..4f35b0aa 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -2,8 +2,8 @@ import inspect import torch from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Type, TypeVar -from transformers import PreTrainedTokenizerBase +from typing import List, Tuple, Optional, TypeVar, Type +from transformers import PreTrainedTokenizerBase, PretrainedConfig from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate @@ -22,10 +22,10 @@ class Model(ABC): device: torch.device, rank: int = 0, world_size: int = 1, - kwargs: dict = {}, + sliding_window: Optional[int] = None, speculate: Optional[int] = None, ): - self.model = model + self.model = model.eval() self.tokenizer = tokenizer # all_special_ids is not set correctly if the rust tokenizer is unpacked @@ -40,7 +40,8 @@ class Model(ABC): self.device = device self.rank = rank self.world_size = world_size - self.kwargs = kwargs + self.sliding_window = sliding_window if sliding_window != -1 else None + if speculate is None: speculate = get_speculate() self.speculate = speculate @@ -54,10 +55,14 @@ class Model(ABC): @property def info(self) -> InfoResponse: + if self.requires_padding and self.sliding_window is not None: + raise NotImplementedError("sliding_window is not implemented with padding") + return InfoResponse( requires_padding=self.requires_padding, dtype=str(self.dtype), device_type=self.device.type, + window_size=self.sliding_window, speculate=self.speculate, ) @@ -72,21 +77,28 @@ class Model(ABC): ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]: raise NotImplementedError - def warmup(self, batch: B, max_total_tokens: int): + def warmup(self, batch: B) -> Optional[int]: self.generate_token(batch) + return None def decode_token( self, all_input_ids: List[int], prefix_offset: int = 0, read_offset: int = 0, + skip_special_tokens: bool = False, ) -> Tuple[str, int, int]: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" # The prefix text is necessary only to defeat cleanup algorithms in the decode # which decide to add a space or not depending on the surrounding ids. - prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset], skip_special_tokens=False) - new_text = self.tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False) + prefix_text = self.tokenizer.decode( + all_input_ids[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + ) + new_text = self.tokenizer.decode( + all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens + ) if len(new_text) > len(prefix_text) and not new_text.endswith("�"): # utf-8 char at the end means it's a potential unfinished byte sequence diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index f07734c2..323e4324 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -1,5 +1,8 @@ -from typing import Optional, List import torch +import torch.distributed + +from typing import Optional, List +from transformers import AutoTokenizer, AutoModelForCausalLM from text_generation_server.models import CausalLM @@ -15,19 +18,29 @@ class SantaCoder(CausalLM): self, model_id: str, revision: Optional[str] = None, + quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - super().__init__( - model_id=model_id, + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, - speculator=speculator, - dtype=dtype, + padding_side="left", + truncation_side="left", trust_remote_code=trust_remote_code, ) - - self.tokenizer.add_special_tokens( + tokenizer.add_special_tokens( { "additional_special_tokens": [ EOD, @@ -39,7 +52,25 @@ class SantaCoder(CausalLM): "pad_token": EOD, } ) + with device: + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) def decode(self, generated_ids: List[int]) -> str: # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False) + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) diff --git a/server/text_generation_server/models/starcoder.py b/server/text_generation_server/models/starcoder.py deleted file mode 100644 index 98e7939a..00000000 --- a/server/text_generation_server/models/starcoder.py +++ /dev/null @@ -1,47 +0,0 @@ -from loguru import logger -import torch -from dataclasses import dataclass -import os -from typing import List, Optional, Type - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch - - -@dataclass -class StarCoderCausalLMBatch(CausalLMBatch): - past_key_values: Optional[List[torch.Tensor]] - - def detach_kv_cache(self): - past_keys = [] - past_values = [] - last_dim = int(self.past_key_values[0].size(dim=-1)/2) - for key_value in self.past_key_values: - past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) - past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) - del self.past_key_values - - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - self.past_key_values = [ - torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)] - - -class StarCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - - super(StarCoder, self).__init__( - model_id=model_id, - revision=revision, - dtype=dtype, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return StarCoderCausalLMBatch \ No newline at end of file diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 93cb179a..f0db89b2 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,93 +1,29 @@ import re import torch -import os -import time import math from PIL import Image from io import BytesIO import base64 -import numpy + from opentelemetry import trace -from loguru import logger from typing import Optional, Tuple, List, Type, Dict -import itertools -import tempfile -import copy -from text_generation_server.models import Model + from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution -from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import ( - CausalLMBatch, - CausalLMRequest, - remove_kv_cache_from_output, - biggest_single_chunk, +from text_generation_server.models.flash_mistral import ( + BaseFlashMistral, + FlashMistralBatch, ) - -from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, -) - -from transformers import AutoProcessor -import text_generation_server.habana_quantization_env as hq_env -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.cache_manager import ( get_cache_manager, ) -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - StoppingCriteria, - make_tokenizer_optional, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from optimum.habana.utils import get_hpu_memory_stats -from optimum.habana.checkpoint_utils import get_ds_injection_policy - -from transformers import ( - AutoTokenizer, - AutoModel, - PreTrainedTokenizerBase, - AutoConfig, -) -from optimum.habana.checkpoint_utils import ( - get_repo_root, - model_on_meta, - write_checkpoints_json, -) - -from text_generation_server.utils.speculate import get_speculate -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.debug import dbg_trace tracer = trace.get_tracer(__name__) IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") -BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048)) -MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) -MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) -PREFILL_WARMUP_BATCH_SIZE_LIST = [] -PREFILL_WARMUP_SEQLEN_LIST = [] -DECODE_WARMUP_BATCH_SIZE_LIST = [] -def round_up(warmup_list:list, num) : - i = 0 - for i in warmup_list: - if num <= i : - break - return i def split(string) -> List[Dict[str, str]]: parts = [] @@ -105,6 +41,30 @@ def split(string) -> List[Dict[str, str]]: return parts + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + def image_text_replacement(image_input, config, image_id) -> str: if config.model_type == "idefics2": # TODO technically depends on image splitting which is not implemented. @@ -117,7 +77,9 @@ def image_text_replacement(image_input, config, image_id) -> str: elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) + from loguru import logger + logger.info(f"Found {num_features} in image of resolution {height}x{width}") return "" * num_features elif config.model_type == "paligemma": @@ -163,7 +125,6 @@ def get_number_of_features(height: int, width: int, config) -> int: image_grid_pinpoints, image_size, ) - unpadded_features, newline_features = get_unpadded_features( height, width, npatches, num_patch_height, num_patch_width ) @@ -178,104 +139,31 @@ def load_data_uri(image_uri: str) -> Image.Image: image = Image.open(BytesIO(content)) return image -class VlmCausalLMBatch(CausalLMBatch): + +class VlmCausalLMBatch(FlashMistralBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @classmethod - def from_tokenized( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - batch_tokenized_inputs, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - - dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') - requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches): + batch = super(VlmCausalLMBatch, cls).concatenate(batches) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch - max_input_length = max(r.data.truncate for r in requests) - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - # TODO: Add support for sparse batches - top_n_tokens = [r.top_n_tokens for r in pb.requests] - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - tokenized_inputs = batch_tokenized_inputs - input_len = tokenized_inputs["input_ids"].shape[1] - - bucket_size = max_input_length - left_padding = max_input_length - input_len - if is_warmup is False: - if input_len < max_input_length : - rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - # Allocate space for first token - if left_padding > 0: - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - - # New input length after left padding - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - - htorch.core.mark_step() - - return cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - ) + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]): + batch = super().filter(request_ids) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor, config, is_warmup): + def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): batch_inputs = [] image_inputs = [] max_truncation = 0 @@ -304,28 +192,16 @@ class VlmCausalLMBatch(CausalLMBatch): image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk['type']}") + batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) - if is_warmup is False: - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - missing_inputs = new_bs - len(requests) - dummy_images = [] - dummy_inputs = [] - if len(batch_inputs) > 0 and len(image_inputs) > 0: - dummy_inputs = [batch_inputs[0]] * missing_inputs - dummy_images = [image_inputs[0]] * missing_inputs - image_inputs += dummy_images - batch_inputs += dummy_inputs - batch_tokenized_inputs = tokenizer( batch_inputs, truncation=True, max_length=max_truncation, - return_tensors="pt", - padding="longest", - return_token_type_ids=False, - ) + add_special_tokens=not config.model_type == "paligemma", + )["input_ids"] if image_inputs: image_input = image_inputs[0] new_image_inputs = { @@ -355,10 +231,9 @@ class VlmCausalLMBatch(CausalLMBatch): config, dtype: torch.dtype, device: torch.device, - is_warmup: bool = False, ) -> "VlmCausalLMBatch": batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config, is_warmup + pb.requests, tokenizer, processor, config ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) if image_inputs is not None: @@ -379,836 +254,127 @@ class VlmCausalLMBatch(CausalLMBatch): batch.image_sizes = None return batch - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id, is_warmup) - - - - @classmethod - def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! Cannot recombine before prefill!") - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - if is_warmup is False : - new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) - batch_id = batches[0].batch_id - device = batches[0].input_ids.device - - input_lengths = [b.input_length for b in batches] - max_input_length = max(input_lengths) - offsets = [max_input_length - b.input_length for b in batches] - - cur_padding = [b.right_padding for b in batches] - # For prefill there is a space allocated only for first token - # Need to add padding to the max total tokens before first decode - - moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] - dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = (batches[dst_batch_idx].batch_size < new_bs) - - # TODO: Add support for changing max seq len, i.e. due to output length bucketing - # FIXME: max_seq_len for non optimized code - if len(batches) > 1: - scenario = 'CONCAT' - elif reshape: - scenario = 'RESHAPE' - elif cur_padding[dst_batch_idx] <= 0: - scenario = 'SHIFT' - offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] - max_input_length = max_input_length + offsets[dst_batch_idx] - else: - # Nothing to do - return batches[0] - - dbg_trace( - scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' - f' reqs:{[len(b) for b in batches]}' - f' offsets:{offsets}' - f' input_lengths:{input_lengths}' - f' cur_padding:{cur_padding}' - f' dst_batch:{dst_batch_idx}') - - grouped_requests = [[req for req in batch.requests] for batch in batches] - flat_requests = list(itertools.chain(*grouped_requests)) - - for i in range(len(batches)): - target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size - batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) - batches[i].realign(target_bs, offsets[i], pad_token_id) - batches[i].split_kv_cache_if_needed(i == dst_batch_idx) - batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) - - top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) - - parameters = [r.data.parameters for r in flat_requests] - # append the dummy parameters for dummy requests - batch_size = batches[dst_batch_idx].batch_size - parameters = pad_next_token_chooser_parameters(parameters, batch_size) - - # update past grammar states - fsm_grammar_states = [0] * batch_size - for batch in batches: - for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[dst_batch_idx].next_token_chooser.dtype, - batches[dst_batch_idx].next_token_chooser.device, - batches[dst_batch_idx].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - input_ids = batches[dst_batch_idx].input_ids - attention_mask = batches[dst_batch_idx].attention_mask - position_ids = batches[dst_batch_idx].position_ids - past_key_values = batches[dst_batch_idx].past_key_values - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=flat_requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - -class VlmCausalLM(Model): - def __init__( - self, - model_class, - model_id: str, - *, - processor_class=AutoProcessor, - processor_kwargs=None, - batch_class=VlmCausalLMBatch, - revision, - dtype, - trust_remote_code: bool, - **kwargs, - ): - adapt_transformers_to_gaudi() - if processor_kwargs is None: - processor_kwargs = {} - self.processor = processor_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - **processor_kwargs, - ) - self.batch_class = batch_class - self.prev_bs = 0 - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - if world_size > 1: - model = self.get_deepspeed_model( - model_class, model_id, dtype, revision - ) - model = self.prepare_model_for_quantization(model) - else: - get_repo_root(model_id) - - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained( - model_id - ) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = model_class.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs - ) - model = self.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace( - "TORCH COMPILE", f'Torch compiling of model') - model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) - - model = self.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in ["llama", "mistral", "llava_next"]: - kwargs["attn_softmax_bf16"] = True - kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": - kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": - kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - super(VlmCausalLM, self).__init__( - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - kwargs=kwargs, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - +class VlmCausalLM(BaseFlashMistral): @property def batch_type(self) -> Type[VlmCausalLMBatch]: - return self.batch_class - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) - - def get_deepspeed_model( - self, - model_class, - model_id: str, - dtype: torch.dtype, - revision: Optional[str] = None - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = { - "revision": revision - } - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = model_class.from_config(config, torch_dtype=dtype) - else: - get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config) - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - write_checkpoints_json(model_id, local_rank, checkpoints_json) - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return { - 'type': rope_scaling, 'factor': float(rope_factor) - } - - def setup_quantization(self, model): - if hq_env.is_quantization_enabled: - htorch.core.quantization._mark_params_as_const(model) - htorch.core.quantization._check_params_as_const(model) - htorch.core.hpu_initialize(model) - return model - - def prepare_model_for_quantization(self, model): - if hq_env.is_quantization_enabled: - if model.config.model_type == "llama": - self.patch_scoped_linear_all_reduce(model) - import habana_quantization_toolkit - habana_quantization_toolkit.prep_model(model) - return model - - def finish_quantization_measurements(self, model): - if hq_env.is_quantization_enabled: - import habana_quantization_toolkit - habana_quantization_toolkit.finish_measurements(self.model) - return model - - def patch_scoped_linear_all_reduce(self, model): - from deepspeed.module_inject.layers import LinearAllreduce - from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce - for name, module in model.named_children(): - if type(module) is LinearAllreduce: - SL = ScopedLinearAllReduce(mod=module) - setattr(model, name, SL) - self.patch_scoped_linear_all_reduce(module) - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) + return VlmCausalLMBatch def forward( - self, - input_ids, - attention_mask, - position_ids, - token_idx, - past_key_values: Optional[List[Tuple]] = None, - pixel_values: Optional[List[torch.Tensor]] = None, - image_sizes: Optional[List[Tuple[int, int]]] = None, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + self, batch: VlmCausalLMBatch + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "token_idx": token_idx, - "pixel_values": pixel_values, - "image_sizes": image_sizes, - } + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices - hpu_kwargs = {} - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama" : - hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 + speculative_ids = batch.speculative_ids - if self.has_position_ids: - kwargs["position_ids"] = position_ids + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) - if bypass_hpu_graph != None: - hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length - kwargs.update(self.kwargs) - model_inputs = self.model.prepare_inputs_for_generation(**kwargs) - if past_key_values is not None: - return self.model.forward(**model_inputs, **hpu_kwargs) + input_ids = new_input_ids + position_ids = new_position_ids else: - outputs = self.model.forward(**model_inputs, **hpu_kwargs) - return outputs.logits, outputs.past_key_values + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: List[VlmCausalLMBatch], is_warmup: bool = False - ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding - token_idx = torch.tensor(token_idx_scalar).to(self.device) + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( - batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append({ - 'next_token_ids': next_token_ids, - 'next_token_logprobs': next_token_logprobs, - }) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append({ - 'req': req, - 'prev_req_idx': req.idx, - 'batch_id': batch_id, - 'seed': batch.next_token_chooser.seeds[req_idx], - 'do_sample': batch.next_token_chooser.do_sample[req_idx], - 'top_n_tokens': batch.top_n_tokens[req_idx], - 'top_token_ids': batch_top_token_ids[req_idx], - 'top_token_logprobs': batch_top_token_logprobs[req_idx], - 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], - }) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # Adjust lengths - batch.input_length += 1 - - # Update position_ids - if prefill: - batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - else: - batch.position_ids += 1 - # Update past key values - if prefill: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) + bs = input_ids.shape[0] + # Try to find an associated cuda graph + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) - - scenario = 'PREFILL' if prefill else 'GENERATE' - if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: - self.model.clear_cache() - self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - dbg_trace( - scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') - #assert batch.right_padding > 0, 'No more room for next token!' - - # Execute batch - if prefill: - # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - batch.pixel_values, - batch.image_sizes, - bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, - ) - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) - batch.logits = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + cuda_graph = None + if cu_seqlen_prefill is not None or cuda_graph is None: + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + return logits, speculative_logits - htorch.core.mark_step() + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + cuda_graph["slots"].fill_(-1) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - start_decode = time.time_ns() + # Replay the graph + cuda_graph["graph"].replay() - # Stage 3. Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() - prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data['req'] - i = req_data['prev_req_idx'] - prev_batch_id = req_data['batch_id'] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] - next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data['do_sample'] - seed = req_data['seed'] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data['top_n_tokens'] - top_token_ids = req_data['top_token_ids'] - top_token_logprobs = req_data['top_token_logprobs'] - grammar_state = req_data['grammar_state'] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: - next_token_text = '' - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0: new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - [next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def batch_from_pb(self, batch, is_warmup): - return VlmCausalLMBatch.from_pb_processor( - batch, - self.tokenizer, - self.processor, - self.model.config, - self.dtype, - self.device, - is_warmup + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None ) - - def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): - batch = copy.deepcopy(request.batches[0]) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_from_pb(batch, is_warmup) - - def warmup(self, request) -> None: - is_warmup = True - batches = [self.batch_from_pb(batch, is_warmup) for batch in request.batches] - - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batches[0]], is_warmup) - except: - raise RuntimeError( - f"Not enough memory to handle {len(batches[0].input_ids)} prefill tokens. " - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - max_input_length = batches[0].input_ids.shape[1] - max_prefill_batch_size = batches[0].input_ids.shape[0] - PREFILL_WARMUP_BATCH_SIZE_LIST = [] - batch_size = 1 - while batch_size <= max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size : - PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) - - seq_len = BASE_IMAGE_TOKENS - PREFILL_WARMUP_SEQLEN_LIST = [] - i = 0 - while seq_len <= max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) - i += 1 - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) - - #Prefill and decode warmup - DECODE_WARMUP_BATCH_SIZE_LIST = [] - prefill_batch = None - decode_batch = None - try: - for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : - for seq_len in PREFILL_WARMUP_SEQLEN_LIST : - batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) - - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - - except: - raise RuntimeError( - f"Not enough memory to handle following prefill and decode warmup." - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing prefill and decode warmup successfully.\n" - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats} " - ) - - max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) - batch_size = max_prefill_batch_size * 2 - # Decode warmup with bigger batch_size - try: - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size: - batches = [] - for i in range(int(batch_size/max_prefill_batch_size)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) - while batch_size <= max_decode_batch_size: - _, decode_batch, _ = self.generate_token(batches, is_warmup) - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - batches.clear() - - for i in range(int(batch_size/max_prefill_batch_size)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) - - batches.clear() - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: - max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2 - batch_size = max_decode_batch_size - for i in range(int(max_decode_batch_size / 2)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup) - DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) - max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS - MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens - except : - raise RuntimeError( - f"Not enough memory to handle batch_size({batch_size}) decode warmup." - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"max_decode_batch_size is {max_decode_batch_size}" - f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing decode warmup successfully.\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats}" - ) - - return MAX_BATCH_TOTAL_TOKENS \ No newline at end of file + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index a4dc4a4c..37c46032 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -1,8 +1,5 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import asyncio import os -import sys import torch import time import signal @@ -17,9 +14,6 @@ from typing import List, Optional from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model -from text_generation_server.pb import generate_pb2_grpc, generate_pb2 -from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -32,7 +26,10 @@ try: except (ImportError, NotImplementedError): # These imports can fail on CPU/Non flash. VLM_BATCH_TYPES = set() -from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION + +from text_generation_server.pb import generate_pb2_grpc, generate_pb2 +from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor +from text_generation_server.models.globals import set_model_id class SignalHandler: @@ -52,25 +49,24 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): self, model: Model, cache: Cache, + quantize: Optional[str], server_urls: List[str], ): self.cache = cache self.model = model + self.quantize = quantize self.server_urls = server_urls # For some reason, inference_mode does not work well with GLOO which we use on CPU - # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul - # op not optimized issue. Will investigate further. - # if model.device.type == "hpu": - # Force inference mode for the lifetime of TextGenerationService - # self._inference_mode_raii_guard = torch._C._InferenceMode(True) - + if model.device.type == "cuda": + # Force inference mode for the lifetime of TextGenerationService + self._inference_mode_raii_guard = torch._C._InferenceMode(True) async def Info(self, request, context): return self.model.info async def Health(self, request, context): - if self.model.device.type == "hpu": - torch.zeros((2, 2)).to("hpu") + if self.model.device.type == "cuda": + torch.zeros((2, 2)).cuda() return generate_pb2.HealthResponse() async def ServiceDiscovery(self, request, context): @@ -93,19 +89,41 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - def batch_from_pb(batch): - return self.model.batch_type.from_pb( - batch, self.model.tokenizer, self.model.dtype, self.model.device + if self.quantize == "gptq": + try: + # When using GPTQ, Exllama kernels need some global kernels + # For which we have the finale shapes only after the model has loaded + # This will allocate those buffers. + from text_generation_server.layers.gptq import ( + create_exllama_buffers, + set_device, + ) + + set_device(self.model.device) + create_exllama_buffers(request.max_prefill_tokens) + except ImportError: + pass + + if ( + self.model.batch_type in VLM_BATCH_TYPES + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb_processor( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, ) - - if self.model.batch_type in VLM_BATCH_TYPES : - max_supported_total_tokens = self.model.warmup(request) - return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens) else: - batches = [batch_from_pb(batch) for batch in request.batches] - self.model.warmup(batches) - return generate_pb2.WarmupResponse() + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + max_supported_total_tokens = self.model.warmup(batch) + return generate_pb2.WarmupResponse( + max_supported_total_tokens=max_supported_total_tokens + ) async def Prefill(self, request, context): start = time.time_ns() @@ -124,8 +142,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - - generations, next_batch, timings = self.model.generate_token([batch]) + + generations, next_batch, timings = self.model.generate_token(batch) self.cache.set(next_batch) return generate_pb2.PrefillResponse( @@ -151,13 +169,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if len(batches) == 0: raise ValueError("All batches are empty") - generations, next_batch, timings = self.model.generate_token(batches) + if len(batches) > 1: + start_concat = time.time_ns() + batch = self.model.batch_type.concatenate(batches) + concat_ns = time.time_ns() - start_concat + else: + batch = batches[0] + concat_ns = None + + generations, next_batch, timings = self.model.generate_token(batch) self.cache.set(next_batch) return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, - concat_ns=None, # TODO: measure concat time + concat_ns=concat_ns, forward_ns=timings[0], decode_ns=timings[1], total_ns=time.time_ns() - start, @@ -168,62 +194,41 @@ def serve( model_id: str, revision: Optional[str], sharded: bool, + quantize: Optional[str], speculate: Optional[int], dtype: Optional[str], trust_remote_code: bool, uds_path: Path, ): - # Remove default handler - #logger.remove() - logger.add( - sys.stdout, - format="{message}", - filter="text_generation_server", - level="INFO", - serialize=False, - backtrace=True, - diagnose=False, - ) - async def serve_inner( model_id: str, revision: Optional[str], sharded: bool = False, + quantize: Optional[str] = None, speculate: Optional[int] = None, dtype: Optional[str] = None, trust_remote_code: bool = False, ): - if not is_driver_compatible(): - logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures") - unix_socket_template = "unix://{}-{}" - logger.info("Server:server_inner: sharded ={}".format(sharded)) - if sharded: - rank = int(os.environ["RANK"]) - logger.info("Server:server_inner: rank ={}".format(rank)) server_urls = [ - unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"])) + unix_socket_template.format(uds_path, rank) + for rank in range(int(os.environ["WORLD_SIZE"])) ] local_url = server_urls[int(os.environ["RANK"])] else: local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] - logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url)) - if dtype == "bfloat16" or None: - data_type = torch.bfloat16 - else: - data_type = torch.float - if revision == "None": - revision = None try: model = get_model( model_id, revision, + sharded, + quantize, speculate, - data_type, - trust_remote_code + dtype, + trust_remote_code, ) except Exception: logger.exception("Error when initializing model") @@ -236,7 +241,7 @@ def serve( ] ) generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( - TextGenerationService(model, Cache(), server_urls), server + TextGenerationService(model, Cache(), quantize, server_urls), server ) SERVICE_NAMES = ( generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, @@ -255,6 +260,6 @@ def serve( set_model_id(model_id) asyncio.run( serve_inner( - model_id, revision, sharded, speculate, dtype, trust_remote_code + model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code ) ) diff --git a/server/text_generation_server/tgi_service.py b/server/text_generation_server/tgi_service.py deleted file mode 100644 index f88c8c8b..00000000 --- a/server/text_generation_server/tgi_service.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from pathlib import Path -from loguru import logger -import sys -from text_generation_server import server -import argparse - - -def main(args): - logger.info("TGIService: starting tgi service .... ") - logger.info( - "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format( - args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path - ) - ) - server.serve( - model_id=args.model_id, - revision=args.revision, - sharded=args.sharded, - speculate=args.speculate, - dtype=args.dtype, - trust_remote_code=args.trust_remote_code, - uds_path=args.uds_path, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model_id", type=str) - parser.add_argument("--revision", type=str) - parser.add_argument("--sharded", type=bool) - parser.add_argument("--speculate", type=int, default=None) - parser.add_argument("--dtype", type=str) - parser.add_argument("--trust_remote_code", type=bool) - parser.add_argument("--uds_path", type=Path) - args = parser.parse_args() - main(args) diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 565a7c3c..08ba808d 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -1,6 +1,3 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import text_generation_server.habana_quantization_env from text_generation_server.utils.convert import convert_file, convert_files from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.weights import Weights @@ -21,9 +18,6 @@ from text_generation_server.utils.tokens import ( FinishReason, Sampling, Greedy, - make_tokenizer_optional, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, ) __all__ = [ diff --git a/server/text_generation_server/utils/debug.py b/server/text_generation_server/utils/debug.py deleted file mode 100644 index ef8d437b..00000000 --- a/server/text_generation_server/utils/debug.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import os -import glob -import time - -from optimum.habana.utils import to_gb_rounded -import habana_frameworks.torch as htorch - -START_TS = None -DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME') -if 'GRAPH_VISUALIZATION' in os.environ: - for f in glob.glob('.graph_dumps/*'): - os.remove(f) - - -def count_hpu_graphs(): - return len(glob.glob('.graph_dumps/*PreGraph*')) - - -def dbg_trace(tag, txt): - global START_TS - if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0: - if START_TS is None: - START_TS = time.perf_counter() - time_offset = time.perf_counter() - START_TS - mem_stats = htorch.hpu.memory.memory_stats() - mem_used = to_gb_rounded(mem_stats['InUse']) - max_mem_used = to_gb_rounded(mem_stats['MaxInUse']) - print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB ' - f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a')) diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index d370a3d5..3625e6f2 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -44,12 +44,6 @@ class FakeGroup: def initialize_torch_distributed(): - import habana_frameworks.torch.core as htcore - - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - options = None if torch.cuda.is_available(): from torch.distributed import ProcessGroupNCCL @@ -62,11 +56,6 @@ def initialize_torch_distributed(): options = ProcessGroupNCCL.Options() options.is_high_priority_stream = True options._timeout = timedelta(seconds=60) - elif torch.hpu.is_available(): - backend = "hccl" - n_hpus = torch.hpu.device_count() - if world_size > n_hpus: - raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") else: try: import oneccl_bindings_for_pytorch diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py deleted file mode 100644 index a832f755..00000000 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ /dev/null @@ -1,359 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_bwd, custom_fwd - -try: - import triton - import triton.language as tl - from . import custom_autotune - - # code based https://github.com/fpgaminer/GPTQ-triton - @custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, - ) - @triton.jit - def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # eventually avoid overflow - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - -except: - print("triton not installed.") - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) - return output - - -class QuantLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - return output - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 104fc2f0..6b915437 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -1,6 +1,5 @@ import math import torch -import habana_frameworks.torch.core as htcore from loguru import logger from typing import Dict, Union @@ -44,31 +43,37 @@ class StaticWarper: if typical_p is not None and typical_p < 1.0: self.warpers.append(TypicalLogitsWarper(mass=typical_p)) - self.hpu_graph = None + self.cuda_graph = None self.static_scores = None self.static_warped_scores = None self.static_next_logprob = None def __call__(self, scores): - if self.hpu_graph is None: - self.static_scores = scores.clone().contiguous() - self.static_warped_scores = scores.clone().contiguous() - self.static_next_logprob = scores.clone().contiguous() - self.hpu_graph = htcore.hpu.HPUGraph() + if torch.cuda.is_available(): + if self.cuda_graph is None: + self.static_scores = scores + self.cuda_graph = torch.cuda.CUDAGraph() - with htcore.hpu.graph(self.hpu_graph): - local_scores = self.static_scores - for warper in self.warpers: - local_scores = warper(None, local_scores) + with torch.cuda.graph(self.cuda_graph, pool=mempool): + local_scores = self.static_scores + for warper in self.warpers: + local_scores = warper(None, local_scores) - self.static_warped_scores.copy_(local_scores) - # Compute logprobs - self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1)) + self.static_warped_scores = local_scores + # Compute logprobs + self.static_next_logprob = torch.log_softmax( + self.static_warped_scores, -1 + ) - self.static_scores.copy_(scores) - self.hpu_graph.replay() + self.static_scores.copy_(scores) + self.cuda_graph.replay() - return self.static_warped_scores, self.static_next_logprob + return self.static_warped_scores, self.static_next_logprob + + # CPU branch + for warper in self.warpers: + scores = warper(None, scores) + return scores, torch.log_softmax(scores, -1) @lru_cache(10) @@ -78,7 +83,9 @@ def static_warper( top_p: Optional[float], typical_p: Optional[float], ) -> StaticWarper: - return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + return StaticWarper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): @@ -95,13 +102,17 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): self.penalty = penalty - self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) + self.penalty_tensor = torch.tensor( + penalty, dtype=dtype, device=device + ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) + score = torch.where( + score < 0, score * self.penalty_tensor, score / self.penalty_tensor + ) scores.scatter_(1, input_ids, score) return scores @@ -159,11 +170,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): vocab_size = scores.size(1) # Calculate the frequency for each token so far - token_freq = torch.zeros( - batch_size, vocab_size, dtype=scores.dtype, device=scores.device - ) + token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) token_freq.scatter_add_( - 1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device) + 1, input_ids, torch.ones_like(input_ids, dtype=torch.float) ) token_freq /= input_size @@ -190,9 +199,13 @@ class HeterogeneousTemperatureLogitsWarper: The value used to module the logits distribution. """ - def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): + def __init__( + self, temperature: List[float], dtype: torch.dtype, device: torch.device + ): self.temperature = temperature - self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) + self.temperature_tensor = torch.tensor( + temperature, dtype=dtype, device=device + ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: scores.div_(self.temperature_tensor) @@ -231,7 +244,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): min_tokens_to_keep: int = 1, ): self.top_p = top_p - self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) + self.top_p_opposite = 1 - torch.tensor( + top_p, dtype=dtype, device=device + ).unsqueeze(1) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @@ -248,7 +263,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) return warped_scores @@ -296,7 +313,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): disabled = [x == 0 for x in top_k] if any(disabled): - self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1) + self.top_k_disabled_mask = torch.tensor( + disabled, dtype=torch.bool, device=device + ).view(-1, 1) else: self.top_k_disabled_mask = None @@ -332,7 +351,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): self.max_top_k = max(self.top_k) if self.top_k_disabled_mask is not None: - self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None + self.top_k_disabled_mask = ( + self.top_k_disabled_mask[indices] if any(disabled) else None + ) return self return None @@ -398,11 +419,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): if self.disabled_mask is not None: last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) - sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather( + 1, last_ind.view(-1, 1) + ) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) @@ -416,7 +441,9 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): self.mass_tensor = self.mass_tensor[indices] if self.disabled_mask is not None: - self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None + self.disabled_mask = ( + self.disabled_mask[indices] if any(disabled) else None + ) return self return None @@ -553,7 +580,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): mask = torch.full_like(logits, -math.inf) for i in range(logits.shape[0]): fsm = self.fsms[i] - if fsm is None: + if fsm_grammar_states[i] == -1 or fsm is None: continue allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) mask[i, allowed_tokens] = 0 diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 267ea068..22f86b60 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,5 +1,3 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - import re from typing import List, Optional, Tuple, Set, Union @@ -22,22 +20,21 @@ from text_generation_server.utils.logits_process import ( ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor -from transformers.utils import to_py_obj class NextTokenChooser: def __init__( self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - frequency_penalty=0.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + watermark: bool = False, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + frequency_penalty: float = 0.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + do_sample: bool = False, + seed: int = 0, + device: str = "cpu", tokenizer: Optional[PreTrainedTokenizerBase] = None, grammar: str = "", grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE, @@ -70,7 +67,9 @@ class NextTokenChooser: or (typical_p is not None and typical_p < 1.0) ) if has_warpers: - self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + self.static_warper = static_warper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) else: self.static_warper = None @@ -249,12 +248,10 @@ class HeterogeneousNextTokenChooser: tokenizer: PreTrainedTokenizerBase, grammars: List[str], grammar_types: List[int], - fsm_grammar_states: List[int], - quantization_enabled: bool, + fsm_grammar_states=List[int], ): warpers = [] - # TODO: enable watermark with FP8 quantization self.watermark_processor = ( HeterogeneousProcessorWrapper( { @@ -263,7 +260,7 @@ class HeterogeneousNextTokenChooser: if do_watermark } ) - if any(watermark) and not quantization_enabled + if any(watermark) else None ) @@ -435,18 +432,6 @@ class HeterogeneousNextTokenChooser: ) return self - def advance_grammar_single_with_past_state( - self, grammar_state_index: int, next_id: torch.Tensor, past_state: int - ): - if self.grammar_processor is not None: - next_id = next_id.item() - self.fsm_grammar_states[grammar_state_index] = ( - self.grammar_processor.advance_at_index( - next_id, past_state, grammar_state_index, - ) - ) - return self - def filter(self, indices): if self.watermark_processor is not None: self.watermark_processor = self.watermark_processor.filter(indices) @@ -497,7 +482,6 @@ class HeterogeneousNextTokenChooser: device: torch.device, tokenizer: PreTrainedTokenizerBase, fsm_grammar_states: Optional[List[int]] = None, - quantization_enabled: bool = False, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], @@ -517,37 +501,12 @@ class HeterogeneousNextTokenChooser: fsm_grammar_states=( fsm_grammar_states if fsm_grammar_states else [0] * len(pb) ), - quantization_enabled=quantization_enabled, ) -def pad_next_token_chooser_parameters( - parameters: List[generate_pb2.NextTokenChooserParameters], - expected_size: int, -) -> List[generate_pb2.NextTokenChooserParameters]: - # disable all logits processors to minimize padding overhead - empty_parameters = generate_pb2.NextTokenChooserParameters( - temperature=1.0, - top_k=0, - top_p=1.0, - typical_p=1.0, - do_sample=False, - seed=0, - repetition_penalty=1.0, - frequency_penalty=0.0, - watermark=False, - grammar="", - grammar_type=0, - ) - parameters.extend( - [empty_parameters] * (expected_size - len(parameters)) - ) - return parameters - - class Sampling: def __init__(self, seed: int, device: str = "cpu"): - self.generator = torch.Generator("cpu") + self.generator = torch.Generator(device) self.generator.manual_seed(seed) self.seed = seed @@ -583,7 +542,7 @@ class HeterogeneousSampling: self.greedy = Greedy() def __call__(self, logits): - out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device) + out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) if self.greedy_indices: # Computing for all indices is faster than slicing torch.argmax(logits, -1, out=out) @@ -645,8 +604,9 @@ def batch_top_tokens( top_n_indices = (logprobs >= nth_highest).nonzero() _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True) + k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max() # Take a new topk for these new max n values - top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True) + top_k = torch.topk(logprobs, k=k, dim=1, sorted=True) top_n_ishes = top_n_ishes.tolist() top_indices = top_k.indices.tolist() @@ -684,50 +644,3 @@ def batch_top_tokens( batch_top_token_logprobs.append(row_top_token_logprobs) return batch_top_token_ids, batch_top_token_logprobs - - -def make_tokenizer_optional(tokenizer): - class _(type(tokenizer)): - def __call__( - self, - text, - return_tensors, - padding, - return_token_type_ids, - truncation, - max_length - ): - assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer" - assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer" - assert return_token_type_ids == False, "inccorrect input arguments when calling TransparentTokenizer" - assert truncation == True, "inccorrect input arguments when calling TransparentTokenizer" - - def str_token_to_int(i): - if i == '?': - return tokenizer.pad_token_id - else: - return int(i) - all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')] - for inner_text in text] - if padding == "longest": - max_length = max(len(tokens) for tokens in all_tokens) - return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]), - "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])} - - def decode( - self, - token_ids, - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool = None, - **kwargs, - ) -> str: - return ','.join(str(i) for i in to_py_obj(token_ids)) - - import os - if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true": - tokenizer.__class__ = _ - tokenizer.is_transparent = True - - -def is_tokenizer_transparent(tokenizer): - return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True diff --git a/server/text_generation_server/utils/version.py b/server/text_generation_server/utils/version.py deleted file mode 100644 index a72a9ea7..00000000 --- a/server/text_generation_server/utils/version.py +++ /dev/null @@ -1,12 +0,0 @@ -from optimum.habana.utils import get_driver_version -from packaging.version import Version - -MIN_TGI_GAUDI_SYNAPSE_VERSION=Version("1.16.0") - - -def is_driver_compatible(): - driver_version = get_driver_version() - if driver_version is not None: - if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION: - return False - return True \ No newline at end of file diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py index 7f4bf367..5d8f5312 100644 --- a/server/text_generation_server/utils/watermark.py +++ b/server/text_generation_server/utils/watermark.py @@ -34,17 +34,21 @@ class WatermarkLogitsProcessor(LogitsProcessor): # watermarking parameters self.gamma = gamma self.delta = delta - self.rng = torch.Generator(device="cpu") + self.rng = torch.Generator(device=device) self.hash_key = hash_key def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): if isinstance(input_ids, list): - assert len(input_ids) >= 1, "requires at least a 1 token prefix sequence to seed rng" + assert ( + len(input_ids) >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" prev_token = input_ids[-1] else: assert len(input_ids) == 1 input_ids = input_ids[0] - assert input_ids.shape[-1] >= 1, "requires at least a 1 token prefix sequence to seed rng" + assert ( + input_ids.shape[-1] >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" prev_token = input_ids[-1].item() self.rng.manual_seed(self.hash_key * prev_token) @@ -63,7 +67,9 @@ class WatermarkLogitsProcessor(LogitsProcessor): return greenlist_ids @staticmethod - def _calc_greenlist_mask(scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor: + def _calc_greenlist_mask( + scores: torch.FloatTensor, greenlist_token_ids + ) -> torch.BoolTensor: green_tokens_mask = torch.zeros_like(scores) green_tokens_mask[-1, greenlist_token_ids] = 1 final_mask = green_tokens_mask.bool() @@ -76,9 +82,15 @@ class WatermarkLogitsProcessor(LogitsProcessor): scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias return scores - def __call__(self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor) -> torch.FloatTensor: - greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device) - green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=greenlist_ids) + def __call__( + self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor + ) -> torch.FloatTensor: + greenlist_ids = self._get_greenlist_ids( + input_ids, scores.shape[-1], scores.device + ) + green_tokens_mask = self._calc_greenlist_mask( + scores=scores, greenlist_token_ids=greenlist_ids + ) scores = self._bias_greenlist_logits( scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta From 479f1953baa58666c5a78cd5006d6e935c035382 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 24 May 2024 15:36:13 +0200 Subject: [PATCH 002/340] Fix seeded output. (#1949) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../test_chat_llama/test_flash_llama_simple.json | 10 +++++----- integration-tests/models/test_chat_llama.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json index 4cb548d2..8631c076 100644 --- a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json +++ b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally", + "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1712874856, + "created": 1716553098, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native", + "system_fingerprint": "2.0.5-dev0-native", "usage": { "completion_tokens": 100, - "prompt_tokens": 60, - "total_tokens": 160 + "prompt_tokens": 62, + "total_tokens": 162 } } diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py index 11419a0e..10df6dbd 100644 --- a/integration-tests/models/test_chat_llama.py +++ b/integration-tests/models/test_chat_llama.py @@ -35,8 +35,9 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot): ], ) + print(repr(response.choices[0].message.content)) assert ( response.choices[0].message.content - == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally" + == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas" ) assert response == response_snapshot From 742ef9b8e57d24a3a7778d94d1dbd7dc2321ddd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 24 May 2024 15:34:42 +0000 Subject: [PATCH 003/340] Fix (flash) Gemma prefix and enable tests --- integration-tests/models/test_flash_gemma.py | 5 +---- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- server/text_generation_server/models/flash_gemma.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py index 2822b5e2..7ab43111 100644 --- a/integration-tests/models/test_flash_gemma.py +++ b/integration-tests/models/test_flash_gemma.py @@ -3,7 +3,7 @@ import pytest @pytest.fixture(scope="module") def flash_gemma_handle(launcher): - with launcher("gg-hf/gemma-2b", num_shard=1) as handle: + with launcher("google/gemma-2b", num_shard=1) as handle: yield handle @@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle): return flash_gemma_handle.client -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma(flash_gemma, response_snapshot): @@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot): assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_all_params(flash_gemma, response_snapshot): @@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot): assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index ac6fd0e6..cff4b5d5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -423,7 +423,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): super().__init__() embed_norm = config.hidden_size**0.5 - if prefix is None: + if not prefix: prefix = "model" else: prefix = f"{prefix}.model" diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 53bfd064..358883e6 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -57,7 +57,7 @@ class FlashGemma(FlashCausalLM): weights._set_gptq_params(model_id, revision) # TODO hardcoded - prefix = "language_model" + prefix = "" model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) torch.distributed.barrier(group=self.process_group) From 1439b26cd4169a5e83c75c238ba0dff1619fe1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 27 May 2024 14:41:28 +0200 Subject: [PATCH 004/340] Fix GPTQ for models which do not have float16 at the default dtype (simpler) (#1953) # What does this PR do? Fix GPTQ for models which do not have float16 at the default dtype Before this change GPTQ models would not work if the model's default data type is not `float16`. For example, Gemma GPTQ models would fail because the default dtype of Gemma is `bfloat16`. There are two issues: If the default `dtype` is not `float16`, the quantizer's `float16` parameters get converted to that dtype. The kernels cannot deal with non-`float16` types. The same applies to inputs of quantized ops. This is resolved by setting the dtype of gptq/awq-quantized models to `float16`. Simpler version of #1951. **Draft:** just testing... ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../test_flash_gemma_gptq.json | 89 +++++ .../test_flash_gemma_gptq_all_params.json | 89 +++++ .../test_flash_gemma_gptq_load.json | 358 ++++++++++++++++++ .../models/test_flash_gemma_gptq.py | 62 +++ .../text_generation_server/models/__init__.py | 10 +- 5 files changed, 605 insertions(+), 3 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json create mode 100644 integration-tests/models/test_flash_gemma_gptq.py diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json new file mode 100644 index 00000000..760ebf94 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.640625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4296875, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8632812, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1328125, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76660156, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3837891, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9746094, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4189453, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.34375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8852539, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json new file mode 100644 index 00000000..7a168b2e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.65625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.3671875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 604, + "logprob": -0.36938477, + "special": false, + "text": " for" + }, + { + "id": 235248, + "logprob": -1.8046875, + "special": false, + "text": " " + }, + { + "id": 235274, + "logprob": -0.46240234, + "special": false, + "text": "1" + }, + { + "id": 235284, + "logprob": -1.7460938, + "special": false, + "text": "2" + }, + { + "id": 235265, + "logprob": -1.9443359, + "special": false, + "text": "." + }, + { + "id": 235284, + "logprob": -1.4550781, + "special": false, + "text": "2" + }, + { + "id": 235308, + "logprob": -1.0205078, + "special": false, + "text": "5" + }, + { + "id": 235290, + "logprob": -1.0283203, + "special": false, + "text": "-" + }, + { + "id": 235274, + "logprob": -1.2783203, + "special": false, + "text": "1" + }, + { + "id": 235284, + "logprob": 0.0, + "special": false, + "text": "2" + } + ], + "top_tokens": null + }, + "generated_text": "Test request for 12.25-12" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json new file mode 100644 index 00000000..bcb9b378 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.359375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4277344, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4394531, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8613281, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1523438, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76220703, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3642578, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -2.0175781, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4238281, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.328125, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8881836, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4238281, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.859375, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1445312, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.7631836, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3642578, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9960938, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4179688, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3359375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8847656, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.640625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.3671875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4257812, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8789062, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1367188, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76171875, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3515625, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9873047, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4169922, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3320312, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8930664, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.359375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4179688, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4492188, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8574219, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1445312, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.7519531, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3623047, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9707031, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4267578, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3359375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.88427734, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + } +] diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py new file mode 100644 index 00000000..7ed339f4 --- /dev/null +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_gptq_handle(launcher): + with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma_gptq(flash_gemma_gptq_handle): + await flash_gemma_gptq_handle.health(300) + return flash_gemma_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot): + response = await flash_gemma_gptq.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): + response = await flash_gemma_gptq.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq_load( + flash_gemma_gptq, generate_load, response_snapshot +): + responses = await generate_load( + flash_gemma_gptq, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d4a325a9..92a20639 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -263,9 +263,13 @@ def get_model( trust_remote_code: bool, ) -> Model: if dtype is None: - # Keep it as default for now and let - # every model resolve their own default dtype. - dtype = None + if quantize in ["awq", "gptq"]: + # These quantizers only work with float16 params. + dtype = torch.float16 + else: + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16": From 1213b6a817bf920791a3f88e0670fe02d8b8ee2d Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 27 May 2024 10:03:16 -0400 Subject: [PATCH 005/340] Processor config chat template (#1954) This PR loads the `processor_config` similar to the `tokenizer_config` and uses the processor_config's chat_template if the tokenizer_config does not include one. These changes enable chat with idefics2 --- router/src/infer.rs | 10 ++++++++-- router/src/lib.rs | 14 ++++++++++++++ router/src/main.rs | 19 +++++++++++++++++-- router/src/server.rs | 8 +++++--- 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index eef42989..1447e756 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -2,7 +2,8 @@ use crate::validation::{Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token, + HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, + TextMessage, Token, }; use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; @@ -67,6 +68,7 @@ impl Infer { speculate: u32, generation_health: Arc, tokenizer_config: HubTokenizerConfig, + processor_config: HubProcessorConfig, ) -> Self { // Infer shared state let queue = Queue::new(requires_padding, 16, window_size, speculate); @@ -89,6 +91,7 @@ impl Infer { let chat_template = tokenizer_config .chat_template + .or(processor_config.chat_template) .and_then(|t| match t { ChatTemplateVersions::Single(template) => Some(template), ChatTemplateVersions::Multiple(templates) => templates @@ -98,7 +101,10 @@ impl Infer { }) .map(|t| { // .strip() is not supported in minijinja - let t = t.replace(".strip()", " | trim"); + // .capitalize() is not supported in minijinja but we can use | capitalize + let t = t + .replace(".strip()", " | trim") + .replace(".capitalize()", " | capitalize"); ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) }); diff --git a/router/src/lib.rs b/router/src/lib.rs index ba1d9acc..9b3283df 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -80,6 +80,20 @@ impl HubTokenizerConfig { } } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct HubProcessorConfig { + pub chat_template: Option, + pub image_seq_len: usize, + pub processor_class: Option, +} + +impl HubProcessorConfig { + pub fn from_file>(filename: P) -> Option { + let content = std::fs::read_to_string(filename).ok()?; + serde_json::from_str(&content).ok() + } +} + #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { diff --git a/router/src/main.rs b/router/src/main.rs index b11c4526..b526367c 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -14,7 +14,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::config::Config; -use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; +use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; use thiserror::Error; use tokenizers::Tokenizer; use tower_http::cors::AllowOrigin; @@ -206,11 +206,18 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer and model info - let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api { + let ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + processor_config_filename, + model_info, + ) = match api { Type::None => ( Some(local_path.join("tokenizer.json")), Some(local_path.join("config.json")), Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("processor_config.json")), None, ), Type::Api(api) => { @@ -226,6 +233,7 @@ async fn main() -> Result<(), RouterError> { }; let config_filename = api_repo.get("config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); let model_info = if let Some(model_info) = get_model_info(&api_repo).await { Some(model_info) @@ -237,6 +245,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_filename, config_filename, tokenizer_config_filename, + processor_config_filename, model_info, ) } @@ -250,6 +259,7 @@ async fn main() -> Result<(), RouterError> { repo.get("tokenizer.json"), repo.get("config.json"), repo.get("tokenizer_config.json"), + repo.get("processor_config.json"), None, ) } @@ -286,6 +296,10 @@ async fn main() -> Result<(), RouterError> { HubTokenizerConfig::default() }); + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + tracing::info!("Using config {config:?}"); if tokenizer.is_none() { tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); @@ -397,6 +411,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, tokenizer_config, + processor_config, messages_api_enabled, disable_grammar_support, max_client_batch_size, diff --git a/router/src/server.rs b/router/src/server.rs index e7570ded..64ec31eb 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,9 +5,9 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, - PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, - Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Infer, + Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, + TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1395,6 +1395,7 @@ pub async fn run( ngrok_authtoken: Option, ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, + processor_config: HubProcessorConfig, messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, @@ -1495,6 +1496,7 @@ pub async fn run( shard_info.speculate, generation_health, tokenizer_config, + processor_config, ); // Duration buckets From e3d4483f9b6f8943433ae3291eab40bcacbecc85 Mon Sep 17 00:00:00 2001 From: Moritz Laurer <41862082+MoritzLaurer@users.noreply.github.com> Date: Mon, 27 May 2024 17:31:06 +0200 Subject: [PATCH 006/340] fix small typo and broken link (#1958) # What does this PR do? Fix a typo; fix a broken link; add one sentence in the guidance docs to make the word "grammar" less abstract ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @drbh --- docs/source/basic_tutorials/train_medusa.md | 2 +- docs/source/conceptual/guidance.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/basic_tutorials/train_medusa.md b/docs/source/basic_tutorials/train_medusa.md index 76cb6bed..ba2e43b7 100644 --- a/docs/source/basic_tutorials/train_medusa.md +++ b/docs/source/basic_tutorials/train_medusa.md @@ -1,6 +1,6 @@ # Train Medusa -This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation.md) for more information on how Medusa works and speculation in general. +This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and speculation in general. ## What are the benefits of training a Medusa model? diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md index a566c4a6..1f3ff33a 100644 --- a/docs/source/conceptual/guidance.md +++ b/docs/source/conceptual/guidance.md @@ -2,11 +2,11 @@ ## What is Guidance? -Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. +Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON. ## How is it used? -Guidance can be in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance: +Guidance can be implemented in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance: Technically, guidance can be used to generate: From cbd5d6710127e4efcb3b1351a62ae6ba4c38688f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 May 2024 14:52:17 +0200 Subject: [PATCH 007/340] Upgrade to Axum 0.7 and Hyper 1.0 (Breaking change: disabled ngrok tunneling). (#1959) - Axum upgraded to hyper 1.0 and most of the ecosystem switched so it's our time now - [ngrok-rust](https://github.com/ngrok/ngrok-rust/pull/137/files) hasn't yet, and hasn't for several months now, so let's disabled the feature for the time being. # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- Cargo.lock | 329 ++++++++++++++++++++--------- docs/source/conceptual/guidance.md | 3 +- router/Cargo.toml | 10 +- router/src/server.rs | 49 +---- 4 files changed, 243 insertions(+), 148 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2e75fe8f..8204ae83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -233,13 +233,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.3.4", "bitflags 1.3.2", "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", "itoa", "matchit", "memchr", @@ -251,13 +251,47 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "tokio", "tower", "tower-layer", "tower-service", ] +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core 0.4.3", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "hyper 1.3.1", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.3.4" @@ -267,8 +301,8 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "mime", "rustversion", "tower-layer", @@ -276,20 +310,41 @@ dependencies = [ ] [[package]] -name = "axum-tracing-opentelemetry" -version = "0.14.1" +name = "axum-core" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06985105829f176e9a3f113b1c71cc24e08f600ef0df4e70cd90d144f889e19f" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" dependencies = [ - "axum", + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 0.1.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-tracing-opentelemetry" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" +dependencies = [ + "axum 0.7.5", "futures-core", "futures-util", - "http", - "opentelemetry", + "http 1.1.0", + "opentelemetry 0.21.0", "pin-project-lite", "tower", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.22.0", "tracing-opentelemetry-instrumentation-sdk", ] @@ -756,33 +811,13 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys 0.3.7", -] - [[package]] name = "dirs" version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys 0.4.1", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", + "dirs-sys", ] [[package]] @@ -1139,7 +1174,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", "indexmap 2.2.6", "slab", "tokio", @@ -1202,7 +1237,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" dependencies = [ - "dirs 5.0.1", + "dirs", "futures", "indicatif", "log", @@ -1239,6 +1274,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -1246,15 +1292,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", "pin-project-lite", ] [[package]] -name = "http-range-header" -version = "0.3.1" +name = "http-body" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" +dependencies = [ + "bytes", + "futures-core", + "http 1.1.0", + "http-body 1.0.0", + "pin-project-lite", +] [[package]] name = "httparse" @@ -1279,8 +1342,8 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1292,13 +1355,32 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + [[package]] name = "hyper-timeout" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.28", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1835,7 +1917,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d4fa7ce7c4862db464a37b0b31d89bca874562f034bd7993895572783d02950" dependencies = [ "base64 0.21.7", - "hyper", + "hyper 0.14.28", "indexmap 1.9.3", "ipnet", "metrics", @@ -2003,12 +2085,12 @@ dependencies = [ "async-rustls", "async-trait", "awaitdrop", - "axum", + "axum 0.6.20", "base64 0.13.1", "bytes", "futures", "hostname", - "hyper", + "hyper 0.14.28", "muxado", "once_cell", "parking_lot", @@ -2287,19 +2369,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54" dependencies = [ "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_sdk 0.20.0", ] [[package]] -name = "opentelemetry-http" -version = "0.9.0" +name = "opentelemetry" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b" +checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ - "async-trait", - "bytes", - "http", - "opentelemetry_api", + "futures-core", + "futures-sink", + "indexmap 2.2.6", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", ] [[package]] @@ -2310,11 +2396,11 @@ checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275" dependencies = [ "async-trait", "futures-core", - "http", + "http 0.2.12", "opentelemetry-proto", "opentelemetry-semantic-conventions", "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_sdk 0.20.0", "prost 0.11.9", "thiserror", "tokio", @@ -2328,7 +2414,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" dependencies = [ "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_sdk 0.20.0", "prost 0.11.9", "tonic 0.9.2", ] @@ -2339,7 +2425,7 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269" dependencies = [ - "opentelemetry", + "opentelemetry 0.20.0", ] [[package]] @@ -2371,7 +2457,7 @@ dependencies = [ "futures-util", "once_cell", "opentelemetry_api", - "ordered-float", + "ordered-float 3.9.2", "percent-encoding", "rand", "regex", @@ -2381,6 +2467,26 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "opentelemetry_sdk" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f16aec8a98a457a52664d69e0091bac3a0abd18ead9b641cb00202ba4e0efe4" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.21.0", + "ordered-float 4.2.0", + "percent-encoding", + "rand", + "thiserror", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -2396,6 +2502,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" @@ -2918,9 +3033,9 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-tls", "ipnet", "js-sys", @@ -2934,7 +3049,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "system-configuration", "tokio", "tokio-native-tls", @@ -2987,9 +3102,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "6.8.1" +version = "8.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a36224c3276f8c4ebc8c20f158eca7ca4359c8db89991c4925132aaaf6702661" +checksum = "19549741604902eb99a7ed0ee177a0663ee1eda51a29f71401f166e47e77806a" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -3260,15 +3375,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shellexpand" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" -dependencies = [ - "dirs 4.0.0", -] - [[package]] name = "signal-hook" version = "0.3.17" @@ -3628,7 +3734,7 @@ dependencies = [ "ngrok", "nohash-hasher", "once_cell", - "opentelemetry", + "opentelemetry 0.20.0", "opentelemetry-otlp", "rand", "regex", @@ -3642,7 +3748,7 @@ dependencies = [ "tokio-stream", "tower-http", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", "utoipa-swagger-ui", @@ -3893,15 +3999,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" dependencies = [ "async-trait", - "axum", + "axum 0.6.20", "base64 0.21.7", "bytes", "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-timeout", "percent-encoding", "pin-project", @@ -3976,17 +4082,15 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "bitflags 2.5.0", "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-range-header", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", "pin-project-lite", "tower-layer", "tower-service", @@ -4066,8 +4170,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8" dependencies = [ "once_cell", - "opentelemetry", - "opentelemetry_sdk", + "opentelemetry 0.20.0", + "opentelemetry_sdk 0.20.0", "smallvec", "tracing", "tracing-core", @@ -4076,16 +4180,33 @@ dependencies = [ ] [[package]] -name = "tracing-opentelemetry-instrumentation-sdk" -version = "0.14.2" +name = "tracing-opentelemetry" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f523eba1b52bb854b804d43a039aafeaee5a623015065adbfef8016825319c15" +checksum = "c67ac25c5407e7b961fafc6f7e9aa5958fd297aada2d20fa2ae1737357e55596" dependencies = [ - "http", - "opentelemetry-http", - "opentelemetry_api", + "js-sys", + "once_cell", + "opentelemetry 0.21.0", + "opentelemetry_sdk 0.21.2", + "smallvec", "tracing", - "tracing-opentelemetry", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time", +] + +[[package]] +name = "tracing-opentelemetry-instrumentation-sdk" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c" +dependencies = [ + "http 1.1.0", + "opentelemetry 0.21.0", + "tracing", + "tracing-opentelemetry 0.22.0", ] [[package]] @@ -4242,9 +4363,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utoipa" -version = "3.5.0" +version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82b1bc5417102a73e8464c686eef947bdfb99fcdfc0a4f228e81afa9526470a" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ "indexmap 2.2.6", "serde", @@ -4439,6 +4560,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa30049b1c872b72c89866d458eae9f20380ab280ffd1b1e18df2d3e2d98cfe0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md index 1f3ff33a..ad1fc2ec 100644 --- a/docs/source/conceptual/guidance.md +++ b/docs/source/conceptual/guidance.md @@ -2,7 +2,8 @@ ## What is Guidance? -Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON. + +Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON. ## How is it used? diff --git a/router/Cargo.toml b/router/Cargo.toml index d164183e..fdfe1a5b 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,8 +16,8 @@ path = "src/main.rs" [dependencies] async-stream = "0.3.5" -axum = { version = "0.6.20", features = ["json"] } -axum-tracing-opentelemetry = "0.14.1" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" @@ -36,12 +36,12 @@ thiserror = "1.0.48" tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.14" -tower-http = { version = "0.4.4", features = ["cors"] } +tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } -utoipa = { version = "3.5.0", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" } diff --git a/router/src/server.rs b/router/src/server.rs index 64ec31eb..f11812e2 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1392,8 +1392,8 @@ pub async fn run( addr: SocketAddr, allow_origin: Option, ngrok: bool, - ngrok_authtoken: Option, - ngrok_edge: Option, + _ngrok_authtoken: Option, + _ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, processor_config: HubProcessorConfig, messages_api_enabled: bool, @@ -1666,46 +1666,9 @@ pub async fn run( if ngrok { #[cfg(feature = "ngrok")] { - use ngrok::config::TunnelBuilder; - - let _ = addr; - - let authtoken = - ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); - - let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling"); - - let tunnel = ngrok::Session::builder() - .authtoken(authtoken) - .connect() - .await - .unwrap() - .labeled_tunnel() - .label("edge", edge); - - let listener = tunnel.listen().await.unwrap(); - - // Run prom metrics and health locally too - tokio::spawn( - axum::Server::bind(&addr) - .serve( - Router::new() - .route("/health", get(health)) - .route("/metrics", get(metrics)) - .layer(Extension(health_ext)) - .layer(Extension(prom_handle)) - .into_make_service(), - ) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()), - ); + panic!("ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable."); // Run server - axum::Server::builder(listener) - .serve(app.into_make_service()) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) - .await?; } #[cfg(not(feature = "ngrok"))] { @@ -1718,9 +1681,9 @@ pub async fn run( } } else { // Run server - axum::Server::bind(&addr) - .serve(app.into_make_service()) - // Wait until all requests are finished to shut down + + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await?; } From 9a1475d816071dace5c75ae39d7f2af5090c2023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 28 May 2024 07:25:14 +0000 Subject: [PATCH 008/340] Fix (non-container) pytest stdout buffering-related lock-up Two issues: 1. When one of the stdout/stderr pipe buffers of a process started with `subprocess.Popen` is full, the process can get blocked until the buffer is drained. 2. Calling `Popen.wait` can deadlock when called before draining the pipe buffers (if they are full). This avoids the issue altogether by giving the child process a temporary file to write to. --- integration-tests/conftest.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index ae3f977b..d81b8736 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -7,9 +7,10 @@ import os import docker import json import math +import shutil +import tempfile import time import random -import re from docker.errors import NotFound from typing import Optional, List, Dict @@ -347,19 +348,22 @@ def launcher(event_loop): if not use_flash_attention: env["USE_FLASH_ATTENTION"] = "false" - with subprocess.Popen( - args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env - ) as process: - yield ProcessLauncherHandle(process, port) + with tempfile.TemporaryFile("w+") as tmp: + # We'll output stdout/stderr to a temporary file. Using a pipe + # cause the process to block until stdout is read. + with subprocess.Popen( + args, + stdout=tmp, + stderr=subprocess.STDOUT, + env=env, + ) as process: + yield ProcessLauncherHandle(process, port) - process.terminate() - process.wait(60) + process.terminate() + process.wait(60) - launcher_output = process.stdout.read().decode("utf-8") - print(launcher_output, file=sys.stderr) - - process.stdout.close() - process.stderr.close() + tmp.seek(0) + shutil.copyfileobj(tmp, sys.stderr) if not use_flash_attention: del env["USE_FLASH_ATTENTION"] From 2b204f04795699bac89b5cd94311d2c95727a738 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 May 2024 16:55:36 +0200 Subject: [PATCH 009/340] Fixing the text part from tokenizer endpoint. (#1967) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- router/src/server.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/router/src/server.rs b/router/src/server.rs index f11812e2..eb7ba2a0 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1333,7 +1333,8 @@ async fn tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = input.chars().skip(start).take(stop - start).collect(); + let text: String = + String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); SimpleToken { id, text, From 4dca35fc62710285022a8a74aadbc6477195fe0b Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 29 May 2024 12:42:11 -0400 Subject: [PATCH 010/340] feat: adjust attn weight loading logic (#1975) This PR updates `load_attention` to prefer loading specific attention based on the model type. Additionally there were two cases where `TensorParallelColumnLinear.load_multi` was called and this reduces it to a single path --- .../custom_modeling/flash_llama_modeling.py | 48 ++++++++----------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6e23aa2b..f722bf73 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -49,37 +49,31 @@ if SYSTEM == "rocm": def load_attention(config, prefix, weights): bias = config.attention_bias - if config.num_attention_heads != config.num_key_value_heads: - return TensorParallelColumnLinear.load_multi( + + # if specific model type, load the correct attention + if config.model_type == "phi3": + return TensorParallelColumnLinear.load_qkv( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, + prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias, ) - else: - if config.model_type == "baichuan": - return TensorParallelColumnLinear.load_qkv( - config, - prefix=f"{prefix}.W_pack", - weights=weights, - bias=bias, - ) - elif config.model_type == "phi3": - return TensorParallelColumnLinear.load_qkv( - config, - prefix=f"{prefix}.qkv_proj", - weights=weights, - bias=bias, - ) - else: - return TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=bias, - ) + elif config.model_type == "baichuan": + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.W_pack", + weights=weights, + bias=bias, + ) + + # otherwise, load the default attention based on the number of heads + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=bias, + ) class FlashLlamaAttention(torch.nn.Module): From 628d6a13dab6ca086d7e1b288cc4df6624486f90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 28 May 2024 09:51:31 +0000 Subject: [PATCH 011/340] Add support for exl2 quantization Mostly straightforward, changes to existing code: * Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict. * Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM. --- docs/source/basic_tutorials/launcher.md | 1 + docs/source/conceptual/guidance.md | 1 - integration-tests/conftest.py | 22 +- .../test_flash_llama_exl2.json | 84 +++++ .../test_flash_llama_exl2_all_params.json | 84 +++++ .../test_flash_llama_exl2_load.json | 338 ++++++++++++++++++ .../models/test_flash_llama_exl2.py | 73 ++++ launcher/src/main.rs | 12 + server/text_generation_server/cli.py | 1 + server/text_generation_server/layers/exl2.py | 23 ++ .../layers/gptq/__init__.py | 22 ++ .../layers/gptq/exllama.py | 26 +- .../layers/gptq/exllamav2.py.rej | 253 +++++++++++++ .../text_generation_server/layers/linear.py | 49 +-- .../layers/tensor_parallel.py | 70 +++- .../text_generation_server/models/__init__.py | 11 +- .../custom_modeling/flash_dbrx_modeling.py | 11 +- .../custom_modeling/flash_llama_modeling.py | 2 +- .../custom_modeling/flash_mistral_modeling.py | 47 +-- .../flash_santacoder_modeling.py | 12 +- .../models/flash_llama.py | 2 +- server/text_generation_server/server.py | 2 +- .../text_generation_server/utils/weights.py | 105 +++++- 23 files changed, 1143 insertions(+), 108 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json create mode 100644 integration-tests/models/test_flash_llama_exl2.py create mode 100644 server/text_generation_server/layers/exl2.py create mode 100644 server/text_generation_server/layers/gptq/exllamav2.py.rej diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 1e5b6fd2..c00d2e1a 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -62,6 +62,7 @@ Options: Possible values: - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from + - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md index ad1fc2ec..3059e3de 100644 --- a/docs/source/conceptual/guidance.md +++ b/docs/source/conceptual/guidance.md @@ -2,7 +2,6 @@ ## What is Guidance? - Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON. ## How is it used? diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index d81b8736..2ef85da6 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") class ResponseComparator(JSONSnapshotExtension): rtol = 0.2 + ignore_logprob = False def serialize( self, @@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension): return ( token.id == other.id and token.text == other.text - and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) + and ( + self.ignore_logprob + or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) + ) and token.special == other.special ) @@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension): prefill_token.id == other.id and prefill_token.text == other.text and ( - math.isclose( - prefill_token.logprob, other.logprob, rel_tol=self.rtol + self.ignore_logprob + or math.isclose( + prefill_token.logprob, + other.logprob, + rel_tol=self.rtol, ) if prefill_token.logprob is not None else prefill_token.logprob == other.logprob @@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator): rtol = 0.75 +class IgnoreLogProbResponseComparator(ResponseComparator): + ignore_logprob = True + + class LauncherHandle: def __init__(self, port: int): self.client = AsyncClient(f"http://localhost:{port}") @@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot): return snapshot.use_extension(GenerousResponseComparator) +@pytest.fixture +def ignore_logprob_response_snapshot(snapshot): + return snapshot.use_extension(IgnoreLogProbResponseComparator) + + @pytest.fixture(scope="module") def event_loop(): loop = asyncio.get_event_loop() diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json new file mode 100644 index 00000000..f6e4bb90 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.4375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9316406, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.5136719, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.7783203, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2314453, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -2.0019531, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.5009766, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.057434082, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4912109, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2636719, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.4042969, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json new file mode 100644 index 00000000..6b38e709 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.453125, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.9980469, + "special": false, + "text": "." + }, + { + "id": 578, + "logprob": -0.15795898, + "special": false, + "text": " The" + }, + { + "id": 3622, + "logprob": -1.0458984, + "special": false, + "text": " server" + }, + { + "id": 31680, + "logprob": -1.3623047, + "special": false, + "text": " responds" + }, + { + "id": 449, + "logprob": 0.0, + "special": false, + "text": " with" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 330, + "logprob": -0.5678711, + "special": false, + "text": " \"" + }, + { + "id": 1049, + "logprob": -0.12322998, + "special": false, + "text": "200" + }, + { + "id": 10619, + "logprob": 0.0, + "special": false, + "text": " OK" + }, + { + "id": 1, + "logprob": 0.0, + "special": false, + "text": "\"" + } + ], + "top_tokens": null + }, + "generated_text": "Test request. The server responds with a \"200 OK\"" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json new file mode 100644 index 00000000..ed369a87 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.453125, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9785156, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.4941406, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.79345703, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2324219, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.9794922, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4892578, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.058258057, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4892578, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2783203, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3945312, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.40625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9433594, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.4726562, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.8022461, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2509766, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.984375, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4677734, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.059173584, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4990234, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2822266, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3867188, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.421875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9511719, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.46875, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.77490234, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2558594, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.984375, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4990234, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.059143066, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4941406, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2578125, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3964844, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.4140625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9101562, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.5039062, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.8076172, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2236328, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.9853516, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4892578, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.056671143, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.5107422, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2597656, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.4042969, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + } +] diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py new file mode 100644 index 00000000..18319f60 --- /dev/null +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -0,0 +1,73 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_exl2_handle(launcher): + with launcher( + "turboderp/Llama-3-8B-Instruct-exl2", + revision="2.5bpw", + # Set max input length to avoid OOM due to extremely large + # scratch buffer. + max_input_length=1024, + num_shard=1, + quantize="exl2", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_exl2(flash_llama_exl2_handle): + await flash_llama_exl2_handle.health(300) + return flash_llama_exl2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): + response = await flash_llama_exl2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2_all_params( + flash_llama_exl2, ignore_logprob_response_snapshot +): + response = await flash_llama_exl2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert ( + response.generated_text == 'Test request. The server responds with a "200 OK"' + ) + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2_load( + flash_llama_exl2, generate_load, ignore_logprob_response_snapshot +): + responses = await generate_load( + flash_llama_exl2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == ignore_logprob_response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 88b0db57..6c4e85d6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -57,6 +57,10 @@ enum Quantization { /// Should be a drop-in replacement to bitsandbytes with much better performance. /// Kernels are from Eetq, + /// Variable bit quantization. Requires a specific EXL2 quantized model: + /// . Requires exllama2 kernels and does + /// not support tensor parallelism (num_shard > 1). + Exl2, /// 4 bit quantization. Requires a specific GTPQ quantized model: . /// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// triton kernel (wider support) when it's not. @@ -97,6 +101,9 @@ impl std::fmt::Display for Quantization { Quantization::BitsandbytesFP4 => { write!(f, "bitsandbytes-fp4") } + Quantization::Exl2 => { + write!(f, "exl2") + } Quantization::Gptq => { write!(f, "gptq") } @@ -1475,6 +1482,11 @@ fn main() -> Result<(), LauncherError> { let num_shard = find_num_shards(args.sharded, args.num_shard)?; if num_shard > 1 { + if matches!(args.quantize, Some(Quantization::Exl2)) { + return Err(LauncherError::ArgumentValidation( + "Sharding is currently not supported with `exl2` quantization".into(), + )); + } tracing::info!("Sharding model on {num_shard} processes"); } diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index ad623ccc..16375ecd 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -19,6 +19,7 @@ class Quantization(str, Enum): gptq = "gptq" awq = "awq" eetq = "eetq" + exl2 = "exl2" fp8 = "fp8" diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py new file mode 100644 index 00000000..f6cb729e --- /dev/null +++ b/server/text_generation_server/layers/exl2.py @@ -0,0 +1,23 @@ +import torch +from dataclasses import dataclass + + +@dataclass +class Exl2Weight: + """ + Exllama2 exl2 quantized weights. + """ + + q_weight: torch.Tensor + q_scale: torch.Tensor + q_invperm: torch.Tensor + q_scale_max: torch.Tensor + q_groups: torch.Tensor + + def __post_init__(self): + self.q_scale_max /= 256 + self.q_invperm = self.q_invperm.short() + + @property + def device(self) -> torch.device: + return self.q_weight.device diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 1c46f493..1172775f 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,9 +1,31 @@ +from dataclasses import dataclass import os +from typing import Optional import torch from text_generation_server.utils.import_utils import ( SYSTEM, ) + +@dataclass +class GPTQWeight: + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: Optional[torch.Tensor] + bits: int + groupsize: int + use_exllama: bool + + def __post_init__(self): + if self.scales.dtype == torch.float: + self.scales = self.scales.half() + + @property + def device(self) -> torch.device: + return self.qweight.device + + try: major, _minor = torch.cuda.get_device_capability() except Exception: diff --git a/server/text_generation_server/layers/gptq/exllama.py b/server/text_generation_server/layers/gptq/exllama.py index 32f817db..4875af38 100644 --- a/server/text_generation_server/layers/gptq/exllama.py +++ b/server/text_generation_server/layers/gptq/exllama.py @@ -1,3 +1,4 @@ +from text_generation_server.utils.weights import GPTQWeight import torch from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params @@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int): class Ex4bitLinear(torch.nn.Module): """Linear layer implementation with per-group 4-bit quantization of the weights""" - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + def __init__(self, weight: GPTQWeight, bias): super().__init__() global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE - assert bits == 4 + assert weight.bits == 4 - self.device = qweight.device - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.g_idx = g_idx.cpu() if g_idx is not None else None + self.device = weight.qweight.device + self.qweight = weight.qweight + self.qzeros = weight.qzeros + self.scales = weight.scales + self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None self.bias = bias if bias is not None else None if self.g_idx is not None and ( (self.g_idx == 0).all() or torch.equal( - g_idx.cpu(), + weight.g_idx.cpu(), torch.tensor( - [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32 + [i // weight.groupsize for i in range(weight.g_idx.shape[0])], + dtype=torch.int32, ), ) ): @@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module): self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index ) - self.height = qweight.shape[0] * 8 - self.width = qweight.shape[1] + self.height = weight.qweight.shape[0] * 8 + self.width = weight.qweight.shape[1] # Infer groupsize from height of qzeros self.groupsize = None @@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module): self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) if self.groupsize is not None: - assert groupsize == self.groupsize + assert weight.groupsize == self.groupsize # Handle act-order matrix if self.g_idx is not None: diff --git a/server/text_generation_server/layers/gptq/exllamav2.py.rej b/server/text_generation_server/layers/gptq/exllamav2.py.rej new file mode 100644 index 00000000..ff13179b --- /dev/null +++ b/server/text_generation_server/layers/gptq/exllamav2.py.rej @@ -0,0 +1,253 @@ +diff a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py (rejected hunks) +@@ -1,10 +1,15 @@ + # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 + ++from dataclasses import dataclass ++from typing import Optional + import torch + import torch.nn as nn + + from loguru import logger + ++from text_generation_server.layers.exl2 import Exl2Weight ++from text_generation_server.layers.gptq import GPTQWeight ++ + try: + from exllamav2_kernels import make_q_matrix, gemm_half_q_half + except ImportError: +@@ -15,6 +20,15 @@ except ImportError: + none_tensor = torch.empty((1, 1), device="meta") + + ++@dataclass ++class _ExtraTensors: ++ """Additional generated quantizer tensors.""" ++ ++ q_group_map: Optional[torch.Tensor] = None ++ q_invperm: Optional[torch.Tensor] = None ++ q_perm: Optional[torch.Tensor] = None ++ ++ + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): + """Matrix multiplication, returns x @ q4""" + output_shape = x.shape[:-1] + (q4_width,) +@@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): + return output.view(output_shape) + + +-# Group map needed for irregular group sizes +- +- +-def make_group_map(q_groups, num_qrows): +- ++def make_group_map(q_groups: torch.Tensor, num_qrows: int): + gr = q_groups.tolist() + group_map = [] + num_groups = len(gr) // 2 +@@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows): + # Create Q matrix + + +-def ext_make_q_matrix(w: dict, temp_dq, key: str = None): ++def ext_make_q_matrix( ++ w: Exl2Weight | GPTQWeight, ++ extra: _ExtraTensors, ++ temp_dq, ++ key: Optional[str] = None, ++): + """ + Create Q matrix + """ + # EXL2 +- # won't work as the moment because the tensors are not the same. +- if "q_weight" in w: +- w["q_scale_max"] /= 256 +- w["q_perm"] = w["q_perm"].short() +- w["q_invperm"] = w["q_invperm"].short() +- +- if "q_group_map" not in w: +- w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) ++ if isinstance(w, Exl2Weight): ++ extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) ++ extra.q_perm = torch.argsort(w.q_invperm).short() + + return make_q_matrix( +- w["q_weight"], +- w["q_perm"], +- w["q_invperm"], +- w["q_scale"], +- w["q_scale_max"], +- w["q_groups"], +- w["q_group_map"], ++ w.q_weight, ++ extra.q_perm, ++ w.q_invperm, ++ w.q_scale, ++ w.q_scale_max, ++ w.q_groups, ++ extra.q_group_map, + none_tensor, + none_tensor, + none_tensor, + temp_dq, + ) + # GPTQ +- elif "qweight" in w: +- if w["scales"].dtype == torch.float: +- w["scales"] = w["scales"].half() ++ elif isinstance(w, GPTQWeight): ++ if w.scales.dtype == torch.float: ++ w.scales = w.scales.half() + + # GPTQ with g_idx (act_order) +- if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): +- w["q_perm"] = torch.empty( +- (w["qweight"].shape[0] * 8,), ++ if w.g_idx is not None and not (w.g_idx == 0).all().item(): ++ extra.q_perm = torch.empty( ++ (w.qweight.shape[0] * 8,), + dtype=torch.short, +- device=w["qweight"].device, ++ device=w.qweight.device, + ) +- w["q_invperm"] = torch.empty_like(w["q_perm"]) ++ extra.q_invperm = torch.empty_like(extra.q_perm) + # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. + return make_q_matrix( +- w["qweight"], +- w["q_perm"], +- w["q_invperm"], ++ w.qweight, ++ extra.q_perm, ++ extra.q_invperm, + none_tensor, + none_tensor, + none_tensor, + none_tensor, +- w["qzeros"], +- w["scales"], +- w["g_idx"].cpu(), ++ w.qzeros, ++ w.scales, ++ w.g_idx.cpu(), + temp_dq, + ) + # GPTQ without g_idx + else: + return make_q_matrix( +- w["qweight"], ++ w.qweight, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, +- w["qzeros"], +- w["scales"], ++ w.qzeros, ++ w.scales, + none_tensor, + temp_dq, + ) +@@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): + + + DEVICE = None +-FIXED_BYTES = 0 + LAYERS = [] + + +@@ -134,8 +143,13 @@ def set_device(device): + + + def create_exllama_buffers(max_total_tokens: int): +- global FIXED_BYTES, LAYERS, DEVICE +- temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) ++ global LAYERS, DEVICE ++ ++ # Find the size of the scratch space. ++ scratch_bytes = max( ++ layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS ++ ) ++ temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) + + for layer in LAYERS: + layer.post_init(temp_dq) +@@ -146,49 +160,48 @@ class QuantLinear(nn.Module): + + """Linear layer implementation with per-group 4-bit quantization of the weights""" + +- # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): +- def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): ++ def __init__( ++ self, ++ weight: Exl2Weight | GPTQWeight, ++ bias: torch.Tensor, ++ ): + super().__init__() +- if bits != 4: +- raise ValueError( +- f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." +- ) ++ + self.q_handle = None +- self.q_tensors = None +- self.bits = bits +- self.maxq = 2**self.bits - 1 +- self.infeatures = qweight.shape[0] // self.bits * 32 +- self.outfeatures = qweight.shape[1] ++ self.q_tensors = weight ++ self.extra_tensors = _ExtraTensors() ++ ++ if isinstance(weight, Exl2Weight): ++ self.infeatures = weight.q_invperm.shape[0] ++ self.outfeatures = weight.q_weight.shape[1] ++ elif isinstance(weight, GPTQWeight): ++ if weight.bits != 4: ++ raise ValueError( ++ f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." ++ ) ++ ++ self.infeatures = weight.qweight.shape[0] // weight.bits * 32 ++ self.outfeatures = weight.qweight.shape[1] ++ + self.padding = -self.outfeatures % 32 + self.outfeatures = self.outfeatures + self.padding + +- self.device = qweight.device +- self.qweight = qweight +- self.qzeros = qzeros +- self.scales = scales +- self.g_idx = g_idx ++ self.device = weight.device + self.bias = bias if bias is not None else None +- self.group_size = groupsize + +- global FIXED_BYTES, LAYERS +- FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) ++ global LAYERS + LAYERS.append(self) + + def post_init(self, temp_dq): +- assert self.qweight.device.type == "cuda" +- assert self.qweight.device.index is not None +- self.q_tensors = { +- "qweight": self.qweight, +- "qzeros": self.qzeros, +- "scales": self.scales, +- "g_idx": self.g_idx, +- } ++ device = self.q_tensors.device ++ assert device.type == "cuda" ++ assert device.index is not None + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) + + # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, + # and `Memory access fault by GPU node-2` will EAT you. + self.temp_dq = temp_dq +- self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) ++ self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) + + def forward(self, x, force_cuda=False): + output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 5bd6aa95..570aa75c 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,6 +1,9 @@ +from typing import Optional import torch from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight if SYSTEM == "rocm": try: @@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize): bias, quant_type="nf4", ) + elif quantize == "exl2": + if not isinstance(weight, Exl2Weight): + raise NotImplementedError( + f"The passed weight is not `exl2` compatible, loader needs to be updated." + ) + + from text_generation_server.layers.gptq import ExllamaQuantLinear + + linear = ExllamaQuantLinear(weight, bias) + elif quantize == "gptq": - try: - qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight - except Exception: + if not isinstance(weight, GPTQWeight): raise NotImplementedError( f"The passed weight is not `gptq` compatible, loader needs to be updated." ) - if use_exllama: + if weight.use_exllama: try: from text_generation_server.layers.gptq import ( ExllamaQuantLinear, @@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize): f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" ) - linear = ExllamaQuantLinear( - qweight, qzeros, scales, g_idx, bias, bits, groupsize - ) + linear = ExllamaQuantLinear(weight, bias) else: from text_generation_server.layers.gptq.quant_linear import QuantLinear linear = QuantLinear( - qweight, - qzeros, - scales, - g_idx, + weight.qweight, + weight.qzeros, + weight.scales, + weight.g_idx, bias, - bits, - groupsize, + weight.bits, + weight.groupsize, ) elif quantize == "awq": - try: - qweight, qzeros, scales, _, bits, groupsize, _ = weight - except Exception: + if not isinstance(weight, GPTQWeight): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) @@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize): from text_generation_server.layers.awq.quantize.qmodule import WQLinear linear = WQLinear( - w_bit=bits, - group_size=groupsize, - qweight=qweight, - qzeros=qzeros, - scales=scales, + w_bit=weight.bits, + group_size=weight.groupsize, + qweight=weight.qweight, + qzeros=weight.qzeros, + scales=weight.scales, bias=bias is not None, ) except ImportError: diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 34b9c51e..afaaa1b8 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -1,7 +1,27 @@ import torch from torch.nn import functional as F -from typing import List +from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear +from text_generation_server.layers.exl2 import Exl2Weight + + +class LayerConcat(torch.nn.Module): + """ + Apply multiple layers to the input and concatenate their + outputs. + """ + + def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1): + """ + `dim` is the dimension along which layer outputs are concatenated. + """ + super().__init__() + self.layers = layers + self.dim = dim + + def forward(self, x: torch.Tensor): + outputs = [layer(x) for layer in self.layers] + return torch.cat(outputs, self.dim) class SuperLayer(torch.nn.Module): @@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer): @staticmethod def load(config, prefix: str, weights): - if weights.process_group.size() > 1: + if config.quantize == "exl2": + try: + # If the piece and LM head embeddings are shared, we have + # non-quantized weights... + weight = weights.get_tensor(f"{prefix}.weight") + except: + # ...otherwise they are quantized. + weight = weights.get_weights_col(prefix, config.quantize) + should_gather = weights.process_group.size() > 1 + elif weights.process_group.size() > 1: try: weight = weights.get_sharded(f"{prefix}.weight", dim=0) should_gather = True @@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer): # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) if config.quantize in ["gptq", "awq", "eetq"]: quantize = None + # See above, exl2 LM head can be quantized or not. + elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight): + quantize = None else: quantize = config.quantize + return TensorParallelHead( get_linear(weight, bias=None, quantize=quantize), process_group=weights.process_group, @@ -108,22 +141,35 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - return cls.load_multi(config, [prefix], weights, bias, dim=0) - - @classmethod - def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) - + weight = weights.get_weights_col(prefix, config.quantize) if bias: - b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] - bias = torch.cat(b, dim=dim) + bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None linear = get_linear(weight, bias, config.quantize) return cls(linear) + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + if config.quantize == "exl2": + linears = [] + for prefix in prefixes: + weight = weights.get_weights_col(prefix, config.quantize) + b = weights.get_tensor(f"{prefix}.bias") if bias else None + linears.append(get_linear(weight, b, config.quantize)) + linear = LayerConcat(linears) + else: + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + class TensorParallelRowLinear(SuperLayer): def __init__(self, linear, process_group): diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 92a20639..d086f87b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -263,7 +263,7 @@ def get_model( trust_remote_code: bool, ) -> Model: if dtype is None: - if quantize in ["awq", "gptq"]: + if quantize in ["awq", "exl2", "gptq"]: # These quantizers only work with float16 params. dtype = torch.float16 else: @@ -402,12 +402,17 @@ def get_model( quantization_config = config_dict.get("quantization_config", None) if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq"}: + if method in {"gptq", "awq", "exl2"}: logger.info(f"Auto selecting quantization method {method}") quantize = method else: logger.info(f"Unknown quantization method {method}") + if quantize == "exl2" and sharded: + raise RuntimeError( + "Sharding is currently not supported with `exl2` quantization" + ) + if model_type == MAMBA: return Mamba( model_id, @@ -881,6 +886,8 @@ def get_model( raise NotImplementedError("4bit quantization is not supported for AutoModel") elif quantize == "eetq": raise NotImplementedError("Eetq quantization is not supported for AutoModel") + elif quantize == "exl2": + raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9d652b67..56bfb9d0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -21,6 +21,7 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger +from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": @@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights): else: g_idx = None - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) else: qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") q = qkv_slice[q_start:q_stop] diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f722bf73..fa3a78f8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.suffix", + prefix=suffix if not prefix else f"{prefix}.{suffix}", weights=weights, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index ef3777da..65043dee 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig): ) -def load_attention(config, prefix, weights): - if config.num_attention_heads != config.num_key_value_heads: - return _load_gqa(config, prefix, weights) - else: - return TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) - - -def _load_gqa(config, prefix: str, weights): - assert config.hidden_size % config.num_attention_heads == 0 - assert config.num_attention_heads % weights.process_group.size() == 0 - - weight = weights.get_multi_weights_col( - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, - dim=0, - ) - - if config.quantize not in ["gptq", "awq"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - head_size = config.hidden_size // config.num_attention_heads - num_heads = config.num_attention_heads // weights.process_group.size() - num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ - (num_heads + 2 * num_key_value_heads) * head_size, - config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) - - class MistralAttention(torch.nn.Module): def __init__( self, @@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) self.o_proj = TensorParallelRowLinear.load( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d2f6d9af..cfa4243f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,6 +5,7 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple +from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.layers import ( TensorParallelRowLinear, @@ -90,8 +91,15 @@ def _load_multi_mqa_gptq( from text_generation_server.layers.gptq import HAS_EXLLAMA - use_exllama = HAS_EXLLAMA - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=HAS_EXLLAMA, + ) if bias: slice_ = weights._get_slice(f"{prefix}.c_attn.bias") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 9a7dfaee..c5cbd2b8 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "exl2"]: weights._set_gptq_params(model_id, revision) prefix = "" diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 37c46032..4118b3f6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -89,7 +89,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - if self.quantize == "gptq": + if self.quantize in {"exl2", "gptq"}: try: # When using GPTQ, Exllama kernels need some global kernels # For which we have the finale shapes only after the model has loaded diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6af7d3fb..710ea680 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,11 +1,14 @@ +from dataclasses import dataclass, field import os from pathlib import Path -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Set, Tuple, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.log import log_once @@ -76,8 +79,9 @@ class Weights: f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert - # u4 which are disguised as int32 - if tensor.dtype not in [torch.int32, torch.int64]: + # u4 which are disguised as int32. Exl2 uses int16 + # as well. + if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) @@ -102,8 +106,8 @@ class Weights: else: raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert - # u4 which are disguised as int32 - if tensor.dtype != torch.int32: + # u4 which are disguised as int32. exl2 uses int16. + if tensor.dtype not in (torch.int16, torch.int32): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor @@ -183,7 +187,15 @@ class Weights: else: g_idx = None - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=False, + ) else: slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] @@ -207,8 +219,34 @@ class Weights: weight = weight.to(dtype=self.dtype) return weight + def get_weights_col(self, prefix: str, quantize: str): + if quantize == "exl2": + try: + q_weight = self.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = self.get_tensor(f"{prefix}.q_scale") + q_invperm = self.get_tensor(f"{prefix}.q_invperm") + q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") + q_groups = self.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + return self.get_multi_weights_col([prefix], quantize, 0) + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize in ["gptq", "awq"]: + if quantize == "exl2": + raise ValueError("get_multi_weights_col is not supported for exl2") + elif quantize in ["gptq", "awq"]: try: qweight = torch.cat( [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -259,7 +297,15 @@ class Weights: else: g_idx = None - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) @@ -282,7 +328,28 @@ class Weights: return tensor def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "gptq": + if quantize == "exl2": + try: + q_weight = self.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = self.get_tensor(f"{prefix}.q_scale") + q_invperm = self.get_tensor(f"{prefix}.q_invperm") + q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") + q_groups = self.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + elif quantize == "gptq": use_exllama = True bits, groupsize, desc_act, quant_method = self._get_gptq_params() @@ -363,7 +430,15 @@ class Weights: // groupsize ).to(dtype=torch.int32) - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) elif quantize == "awq": bits, groupsize, _, _ = self._get_gptq_params() @@ -379,7 +454,15 @@ class Weights: g_idx = None use_exllama = False - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight From f6c5e078d5332bfc8ab169dac642895dd77f5884 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 30 May 2024 07:10:10 +0000 Subject: [PATCH 012/340] Gemma GPTQ checks: skip logprob checks This test fails somewhat regularly due to non-determinism and this test is primarily to verify that we are loading a model which doesn't have `float16` as the default dtype correctly. --- integration-tests/models/test_flash_gemma_gptq.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py index 7ed339f4..8ac5f5a1 100644 --- a/integration-tests/models/test_flash_gemma_gptq.py +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -15,18 +15,20 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot): +async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot): response = await flash_gemma_gptq.generate( "Test request", max_new_tokens=10, decoder_input_details=True ) assert response.details.generated_tokens == 10 - assert response == response_snapshot + assert response == ignore_logprob_response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): +async def test_flash_gemma_gptq_all_params( + flash_gemma_gptq, ignore_logprob_response_snapshot +): response = await flash_gemma_gptq.generate( "Test request", max_new_tokens=10, @@ -44,13 +46,13 @@ async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response == response_snapshot + assert response == ignore_logprob_response_snapshot @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq_load( - flash_gemma_gptq, generate_load, response_snapshot + flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot ): responses = await generate_load( flash_gemma_gptq, "Test request", max_new_tokens=10, n=4 @@ -59,4 +61,4 @@ async def test_flash_gemma_gptq_load( assert len(responses) == 4 assert all([r.generated_text == responses[0].generated_text for r in responses]) - assert responses == response_snapshot + assert responses == ignore_logprob_response_snapshot From dd2d46d9d141b7ed7dd02dc13bdd4b78d421a0bf Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 31 May 2024 07:03:24 -0700 Subject: [PATCH 013/340] Update documentation version to 2.0.4 (#1980) As per title cc @Narsil --- docs/README.md | 10 ++++++++++ docs/source/basic_tutorials/gated_model_access.md | 2 +- docs/source/installation_amd.md | 2 +- docs/source/installation_nvidia.md | 2 +- docs/source/quicktour.md | 4 ++-- 5 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 docs/README.md diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..fb2ff198 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,10 @@ +Documentation available at: https://huggingface.co/docs/text-generation-inference + +## Release + +When making a release, please update the latest version in the documentation with: +``` +export OLD_VERSION="2\.0\.3" +export NEW_VERSION="2\.0\.4" +find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \; +``` diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index 970afa0e..b49c59c9 100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -19,6 +19,6 @@ docker run --gpus all \ --shm-size 1g \ -e HUGGING_FACE_HUB_TOKEN=$token \ -p 8080:80 \ - -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \ + -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \ --model-id $model ``` diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 636d301c..d70953ae 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.0.3-rocm \ + ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \ --model-id $model ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 62e1a3d6..9077f7fd 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.0.3 \ + ghcr.io/huggingface/text-generation-inference:2.0.4 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 6137c6f6..b84de85d 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.0.3 \ + ghcr.io/huggingface/text-generation-inference:2.0.4 \ --model-id $model ``` @@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. ```bash -docker run ghcr.io/huggingface/text-generation-inference:2.0.3 --help +docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help ``` From bdc676f65c36de486e3ca0fe777a7f4532b23e2c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 31 May 2024 17:57:01 +0200 Subject: [PATCH 014/340] Purely refactors paged/attention into `layers/attention` and make hardware differences more obvious with 1 file per hardware. (#1986) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- router/src/infer.rs | 1 - server/Makefile-flash-att-v2 | 4 +- .../layers/attention/__init__.py | 13 + .../layers/attention/cuda.py | 245 +++++++++++++++ .../attention}/flash_attn_triton.py | 0 .../layers/attention/rocm.py | 295 ++++++++++++++++++ .../layers/attention/xpu.py | 76 +++++ .../text_generation_server/models/__init__.py | 39 +-- .../custom_modeling/flash_cohere_modeling.py | 12 +- .../custom_modeling/flash_dbrx_modeling.py | 14 +- .../custom_modeling/flash_gemma_modeling.py | 14 +- .../custom_modeling/flash_gpt2_modeling.py | 12 +- .../custom_modeling/flash_llama_modeling.py | 14 +- .../custom_modeling/flash_mistral_modeling.py | 12 +- .../custom_modeling/flash_mixtral_modeling.py | 12 +- .../custom_modeling/flash_neox_modeling.py | 15 +- .../custom_modeling/flash_phi_modeling.py | 14 +- .../custom_modeling/flash_qwen2_modeling.py | 12 +- .../custom_modeling/flash_rw_modeling.py | 20 +- .../flash_santacoder_modeling.py | 12 +- .../flash_starcoder2_modeling.py | 12 +- .../text_generation_server/models/globals.py | 1 + .../utils/flash_attn.py | 293 ----------------- .../utils/import_utils.py | 2 + .../utils/paged_attention.py | 137 -------- 25 files changed, 754 insertions(+), 527 deletions(-) create mode 100644 server/text_generation_server/layers/attention/__init__.py create mode 100644 server/text_generation_server/layers/attention/cuda.py rename server/text_generation_server/{utils => layers/attention}/flash_attn_triton.py (100%) create mode 100644 server/text_generation_server/layers/attention/rocm.py create mode 100644 server/text_generation_server/layers/attention/xpu.py delete mode 100644 server/text_generation_server/utils/flash_attn.py delete mode 100644 server/text_generation_server/utils/paged_attention.py diff --git a/router/src/infer.rs b/router/src/infer.rs index 1447e756..0410de7d 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -70,7 +70,6 @@ impl Infer { tokenizer_config: HubTokenizerConfig, processor_config: HubProcessorConfig, ) -> Self { - // Infer shared state let queue = Queue::new(requires_padding, 16, window_size, speculate); let shared = Arc::new(Shared { batching_task: Notify::new(), diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 36ef576a..bbff0090 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,11 +1,11 @@ -flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 +flash_att_v2_commit_cuda := v2.5.8 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 flash-attention-v2-cuda: # Clone flash attention pip install -U packaging ninja --no-cache-dir - git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 build-flash-attention-v2-cuda: flash-attention-v2-cuda cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py new file mode 100644 index 00000000..e6cb4edf --- /dev/null +++ b/server/text_generation_server/layers/attention/__init__.py @@ -0,0 +1,13 @@ +from text_generation_server.utils.import_utils import SYSTEM +import os + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") +if SYSTEM == "cuda": + from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +elif SYSTEM == "rocm": + from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +elif SYSTEM == "xpu": + from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py new file mode 100644 index 00000000..583337bd --- /dev/null +++ b/server/text_generation_server/layers/attention/cuda.py @@ -0,0 +1,245 @@ +import torch +from text_generation_server.utils.import_utils import SYSTEM + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +_PARTITION_SIZE = 512 + +try: + from vllm._C import cache_ops + from vllm._C import ops +except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + + +def paged_attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + from vllm._C import ops + + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + + +try: + import flash_attn_2_cuda + + V2 = True +except ImportError: + try: + import flash_attn_cuda + + V2 = False + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + + +SUPPORTS_WINDOWING = V2 +if V2: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + None, + None, + None, + max_s, + max_s, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + False, + None, + ) + +else: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) diff --git a/server/text_generation_server/utils/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py similarity index 100% rename from server/text_generation_server/utils/flash_attn_triton.py rename to server/text_generation_server/layers/attention/flash_attn_triton.py diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py new file mode 100644 index 00000000..2d3601c8 --- /dev/null +++ b/server/text_generation_server/layers/attention/rocm.py @@ -0,0 +1,295 @@ +import os +import torch +from text_generation_server.utils.import_utils import SYSTEM +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +_PARTITION_SIZE = 512 + +use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} +ENGINE = "triton" if use_triton else "ck" + +try: + from vllm._C import cache_ops + from vllm._C import ops +except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + + +def paged_attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + from vllm._C import ops + + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + + +if ENGINE != "triton": + try: + import flash_attn_2_cuda + + logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + except ImportError: + try: + import flash_attn_cuda + + ENGINE = "v1" + logger.info("ROCm: using Flash Attention 1") + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: + + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + + +SUPPORTS_WINDOWING = ENGINE != "v1" +if ENGINE == "ck": + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + if window_size_left != -1: + raise ValueError( + f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + None, + None, + None, + max_s, + max_s, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + False, + None, + ) + +elif ENGINE == "triton": + from .flash_attn_triton import triton_attention + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left != -1: + raise ValueError( + f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + output, _ = triton_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + causal, + softmax_scale, + ) + return output + +else: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py new file mode 100644 index 00000000..d9a096f9 --- /dev/null +++ b/server/text_generation_server/layers/attention/xpu.py @@ -0,0 +1,76 @@ +import intel_extension_for_pytorch as ipex +import torch + +SUPPORTS_WINDOWING = False + + +def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, +): + if window_size_left != -1: + raise ValueError( + f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + return ipex.llm.functional.varlen_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slots + ) + + +def paged_attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + query = query.contiguous() + block_size = value_cache.shape[3] + return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d086f87b..dbe49039 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -80,15 +80,11 @@ try: from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.models.flash_dbrx import FlashDbrx - from text_generation_server.utils.flash_attn import ( - HAS_FLASH_ATTN_V2_CUDA, - HAS_FLASH_ATTN_V2_ROCM, - ) + from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") + SUPPORTS_WINDOWING = False FLASH_ATTENTION = False - HAS_FLASH_ATTN_V2_CUDA = False - HAS_FLASH_ATTN_V2_ROCM = False if FLASH_ATTENTION: __all__.append(FlashGPT2) @@ -262,6 +258,7 @@ def get_model( dtype: Optional[str], trust_remote_code: bool, ) -> Model: + global FLASH_ATTENTION if dtype is None: if quantize in ["awq", "exl2", "gptq"]: # These quantizers only work with float16 params. @@ -412,6 +409,12 @@ def get_model( raise RuntimeError( "Sharding is currently not supported with `exl2` quantization" ) + sliding_window = config_dict.get("sliding_window", -1) + if sliding_window != -1 and not SUPPORTS_WINDOWING: + logger.warning( + f"Flash attention is available, but doesn't support windowing which is required by model {model_id}" + ) + FLASH_ATTENTION = False if model_type == MAMBA: return Mamba( @@ -699,11 +702,7 @@ def get_model( if model_type == MISTRAL: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashMistral( model_id, revision, @@ -726,11 +725,7 @@ def get_model( if model_type == MIXTRAL: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashMixtral( model_id, revision, @@ -753,11 +748,7 @@ def get_model( if model_type == STARCODER2: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashStarcoder2( model_id, revision, @@ -781,11 +772,7 @@ def get_model( if model_type == QWEN2: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING: return FlashQwen2( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index bd8b8016..31109bc9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -25,7 +25,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -281,7 +285,7 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -289,7 +293,7 @@ class FlashCohereAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, key, value, @@ -300,7 +304,7 @@ class FlashCohereAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 56bfb9d0..497956e3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -27,7 +27,11 @@ from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": from vllm.model_executor.layers.fused_moe import fused_moe -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( FastLinear, TensorParallelRowLinear, @@ -424,9 +428,7 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -434,7 +436,7 @@ class DbrxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -445,7 +447,7 @@ class DbrxAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index cff4b5d5..89ca8b5b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -26,7 +26,11 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -221,9 +225,7 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -231,7 +233,7 @@ class FlashGemmaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -243,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d2599f7a..52a7c283 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -25,7 +25,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -213,7 +217,7 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -221,7 +225,7 @@ class FlashGPT2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, key, value, @@ -232,7 +236,7 @@ class FlashGPT2Attention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index fa3a78f8..c0fa09fd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,7 +28,11 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -145,9 +149,7 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -155,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -166,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 65043dee..77a8a384 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -186,7 +190,7 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -196,7 +200,7 @@ class MistralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -208,7 +212,7 @@ class MistralAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index be2d6c45..37cd6f3b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -33,7 +33,11 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from loguru import logger -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( FastLinear, TensorParallelRowLinear, @@ -265,7 +269,7 @@ class MixtralAttention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -275,7 +279,7 @@ class MixtralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -287,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d45cab2e..59e7bf8b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -27,8 +27,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import attention +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -146,9 +149,7 @@ class FlashNeoxAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin) - paged_attention.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(qkv[:, 0]) @@ -156,7 +157,7 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], @@ -167,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index f2efb538..af3206dd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -6,7 +6,11 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -185,16 +189,14 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -205,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 3a6d2db5..2b035c2e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -5,7 +5,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -142,7 +146,7 @@ class Qwen2Attention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -152,7 +156,7 @@ class Qwen2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -164,7 +168,7 @@ class Qwen2Attention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fa463a19..d489c3ba 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -15,7 +15,11 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding -from text_generation_server.utils import flash_attn, paged_attention +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) def load_row(config, prefix: str, weights, bias: bool): @@ -194,9 +198,7 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output attn_output = torch.empty_like(query) @@ -204,7 +206,7 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -215,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], @@ -313,7 +315,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - paged_attention.reshape_and_cache( + reshape_and_cache( kv[:, :, 0].contiguous(), kv[:, :, 1].contiguous(), kv_cache[0], @@ -327,7 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), @@ -338,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index cfa4243f..c8397000 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -6,7 +6,11 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -276,7 +280,7 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - paged_attention.reshape_and_cache( + reshape_and_cache( key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -286,7 +290,7 @@ class FlashMQAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), @@ -297,7 +301,7 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 3e2ce4f9..37486e9d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -26,7 +26,11 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -229,7 +233,7 @@ class Starcoder2Attention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -239,7 +243,7 @@ class Starcoder2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -251,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index e8a11958..11a9f030 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,5 +1,6 @@ import torch import os +from loguru import logger MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py deleted file mode 100644 index 4f5cf10b..00000000 --- a/server/text_generation_server/utils/flash_attn.py +++ /dev/null @@ -1,293 +0,0 @@ -import os -import torch - -from loguru import logger -import math - -from text_generation_server.utils.import_utils import SYSTEM - -if SYSTEM != "xpu": - from text_generation_server.utils.flash_attn_triton import triton_attention - -if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") -HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2_CUDA = False -HAS_FLASH_ATTN_V2_ROCM = False -ROCM_USE_FLASH_ATTN_V2_CK = False -ROCM_USE_FLASH_ATTN_V2_TRITON = False - - -if SYSTEM in {"cuda", "rocm"}: - if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") - - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - is_sm94 = major == 9 and minor == 4 - - if SYSTEM == "rocm": - if ( - os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" - or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1" - ): - ROCM_USE_FLASH_ATTN_V2_TRITON = True - logger.info("ROCm: using Flash Attention 2 Triton implementation.") - else: - ROCM_USE_FLASH_ATTN_V2_CK = True - logger.info( - "ROCm: using Flash Attention 2 Composable Kernel implementation." - ) - - try: - try: - import flash_attn_2_cuda - except ImportError: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - if SYSTEM == "cuda" and not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94): - raise ImportError( - f"AMD GPU with compute capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda" - HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm" - except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - elif SYSTEM == "rocm": - for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name( - idx - ) and "MI250" not in torch.cuda.get_device_name(idx): - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True - -if SYSTEM == "xpu": - import intel_extension_for_pytorch as ipex - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - if window_size_left != -1: - raise ValueError( - f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) - return ipex.llm.functional.varlen_attention( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - True, - False, - None, - ) - -elif HAS_FLASH_ATTN_V2_CUDA: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - None, - None, - None, - max_s, - max_s, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - False, - None, - ) - -elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - if window_size_left != -1: - raise ValueError( - f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) - - # RoCm flash API does not take the window_size_left and window_size_right arguments. - return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - causal, - False, - None, - ) - -elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, - ): - output, _ = triton_attention( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - causal, - softmax_scale, - ) - return output - -elif HAS_FLASH_ATTN: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - return flash_attn_cuda.fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - True, - False, - 0, - None, - ) - -else: - raise NotImplementedError("flash attention is not installed") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 40e57646..d79e36c2 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,4 +1,5 @@ import torch +from loguru import logger def is_xpu_available(): @@ -48,3 +49,4 @@ else: empty_cache = noop synchronize = noop get_free_memory = noop +logger.info(f"Detected system {SYSTEM}") diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py deleted file mode 100644 index 6cc30e6d..00000000 --- a/server/text_generation_server/utils/paged_attention.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -from text_generation_server.utils.import_utils import SYSTEM - -_PARTITION_SIZE = 512 - -if SYSTEM == "xpu": - import intel_extension_for_pytorch as ipex -else: - try: - from vllm._C import cache_ops - from vllm._C import ops - except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if SYSTEM == "xpu": - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots - ) - else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) - - -def attention( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - if SYSTEM == "xpu": - query = query.contiguous() - return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - ) - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) From d1473fab70975c29beb59645d61c109a0b337f7f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 31 May 2024 18:01:43 +0200 Subject: [PATCH 015/340] Fixing exl2 scratch buffer. (#1990) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../layers/gptq/exllamav2.py.rej | 253 ------------------ 1 file changed, 253 deletions(-) delete mode 100644 server/text_generation_server/layers/gptq/exllamav2.py.rej diff --git a/server/text_generation_server/layers/gptq/exllamav2.py.rej b/server/text_generation_server/layers/gptq/exllamav2.py.rej deleted file mode 100644 index ff13179b..00000000 --- a/server/text_generation_server/layers/gptq/exllamav2.py.rej +++ /dev/null @@ -1,253 +0,0 @@ -diff a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py (rejected hunks) -@@ -1,10 +1,15 @@ - # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 - -+from dataclasses import dataclass -+from typing import Optional - import torch - import torch.nn as nn - - from loguru import logger - -+from text_generation_server.layers.exl2 import Exl2Weight -+from text_generation_server.layers.gptq import GPTQWeight -+ - try: - from exllamav2_kernels import make_q_matrix, gemm_half_q_half - except ImportError: -@@ -15,6 +20,15 @@ except ImportError: - none_tensor = torch.empty((1, 1), device="meta") - - -+@dataclass -+class _ExtraTensors: -+ """Additional generated quantizer tensors.""" -+ -+ q_group_map: Optional[torch.Tensor] = None -+ q_invperm: Optional[torch.Tensor] = None -+ q_perm: Optional[torch.Tensor] = None -+ -+ - def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): - """Matrix multiplication, returns x @ q4""" - output_shape = x.shape[:-1] + (q4_width,) -@@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): - return output.view(output_shape) - - --# Group map needed for irregular group sizes -- -- --def make_group_map(q_groups, num_qrows): -- -+def make_group_map(q_groups: torch.Tensor, num_qrows: int): - gr = q_groups.tolist() - group_map = [] - num_groups = len(gr) // 2 -@@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows): - # Create Q matrix - - --def ext_make_q_matrix(w: dict, temp_dq, key: str = None): -+def ext_make_q_matrix( -+ w: Exl2Weight | GPTQWeight, -+ extra: _ExtraTensors, -+ temp_dq, -+ key: Optional[str] = None, -+): - """ - Create Q matrix - """ - # EXL2 -- # won't work as the moment because the tensors are not the same. -- if "q_weight" in w: -- w["q_scale_max"] /= 256 -- w["q_perm"] = w["q_perm"].short() -- w["q_invperm"] = w["q_invperm"].short() -- -- if "q_group_map" not in w: -- w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) -+ if isinstance(w, Exl2Weight): -+ extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) -+ extra.q_perm = torch.argsort(w.q_invperm).short() - - return make_q_matrix( -- w["q_weight"], -- w["q_perm"], -- w["q_invperm"], -- w["q_scale"], -- w["q_scale_max"], -- w["q_groups"], -- w["q_group_map"], -+ w.q_weight, -+ extra.q_perm, -+ w.q_invperm, -+ w.q_scale, -+ w.q_scale_max, -+ w.q_groups, -+ extra.q_group_map, - none_tensor, - none_tensor, - none_tensor, - temp_dq, - ) - # GPTQ -- elif "qweight" in w: -- if w["scales"].dtype == torch.float: -- w["scales"] = w["scales"].half() -+ elif isinstance(w, GPTQWeight): -+ if w.scales.dtype == torch.float: -+ w.scales = w.scales.half() - - # GPTQ with g_idx (act_order) -- if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): -- w["q_perm"] = torch.empty( -- (w["qweight"].shape[0] * 8,), -+ if w.g_idx is not None and not (w.g_idx == 0).all().item(): -+ extra.q_perm = torch.empty( -+ (w.qweight.shape[0] * 8,), - dtype=torch.short, -- device=w["qweight"].device, -+ device=w.qweight.device, - ) -- w["q_invperm"] = torch.empty_like(w["q_perm"]) -+ extra.q_invperm = torch.empty_like(extra.q_perm) - # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. - return make_q_matrix( -- w["qweight"], -- w["q_perm"], -- w["q_invperm"], -+ w.qweight, -+ extra.q_perm, -+ extra.q_invperm, - none_tensor, - none_tensor, - none_tensor, - none_tensor, -- w["qzeros"], -- w["scales"], -- w["g_idx"].cpu(), -+ w.qzeros, -+ w.scales, -+ w.g_idx.cpu(), - temp_dq, - ) - # GPTQ without g_idx - else: - return make_q_matrix( -- w["qweight"], -+ w.qweight, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, -- w["qzeros"], -- w["scales"], -+ w.qzeros, -+ w.scales, - none_tensor, - temp_dq, - ) -@@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): - - - DEVICE = None --FIXED_BYTES = 0 - LAYERS = [] - - -@@ -134,8 +143,13 @@ def set_device(device): - - - def create_exllama_buffers(max_total_tokens: int): -- global FIXED_BYTES, LAYERS, DEVICE -- temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) -+ global LAYERS, DEVICE -+ -+ # Find the size of the scratch space. -+ scratch_bytes = max( -+ layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS -+ ) -+ temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) - - for layer in LAYERS: - layer.post_init(temp_dq) -@@ -146,49 +160,48 @@ class QuantLinear(nn.Module): - - """Linear layer implementation with per-group 4-bit quantization of the weights""" - -- # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): -- def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): -+ def __init__( -+ self, -+ weight: Exl2Weight | GPTQWeight, -+ bias: torch.Tensor, -+ ): - super().__init__() -- if bits != 4: -- raise ValueError( -- f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." -- ) -+ - self.q_handle = None -- self.q_tensors = None -- self.bits = bits -- self.maxq = 2**self.bits - 1 -- self.infeatures = qweight.shape[0] // self.bits * 32 -- self.outfeatures = qweight.shape[1] -+ self.q_tensors = weight -+ self.extra_tensors = _ExtraTensors() -+ -+ if isinstance(weight, Exl2Weight): -+ self.infeatures = weight.q_invperm.shape[0] -+ self.outfeatures = weight.q_weight.shape[1] -+ elif isinstance(weight, GPTQWeight): -+ if weight.bits != 4: -+ raise ValueError( -+ f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." -+ ) -+ -+ self.infeatures = weight.qweight.shape[0] // weight.bits * 32 -+ self.outfeatures = weight.qweight.shape[1] -+ - self.padding = -self.outfeatures % 32 - self.outfeatures = self.outfeatures + self.padding - -- self.device = qweight.device -- self.qweight = qweight -- self.qzeros = qzeros -- self.scales = scales -- self.g_idx = g_idx -+ self.device = weight.device - self.bias = bias if bias is not None else None -- self.group_size = groupsize - -- global FIXED_BYTES, LAYERS -- FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) -+ global LAYERS - LAYERS.append(self) - - def post_init(self, temp_dq): -- assert self.qweight.device.type == "cuda" -- assert self.qweight.device.index is not None -- self.q_tensors = { -- "qweight": self.qweight, -- "qzeros": self.qzeros, -- "scales": self.scales, -- "g_idx": self.g_idx, -- } -+ device = self.q_tensors.device -+ assert device.type == "cuda" -+ assert device.index is not None - temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - - # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, - # and `Memory access fault by GPU node-2` will EAT you. - self.temp_dq = temp_dq -- self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) -+ self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) - - def forward(self, x, force_cuda=False): - output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) From c46a223a6d9231f945d411df1beee12be4f30109 Mon Sep 17 00:00:00 2001 From: Nicholas Broad Date: Fri, 31 May 2024 09:42:14 -0700 Subject: [PATCH 016/340] single char ` addition for docs (#1989) # What does this PR do? I think this will fix the docs from being weirdly formatted. All the sections after MAX_TOP_N_TOKENS don't show up in the bar on the right (https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#maxtopntokens) ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? @merveenoyan --------- Co-authored-by: Nicolas Patry --- docs/source/basic_tutorials/launcher.md | 2 +- launcher/src/main.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index c00d2e1a..acab822e 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -125,7 +125,7 @@ Options: ## MAX_TOP_N_TOKENS ```shell --max-top-n-tokens - This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking + This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking [env: MAX_TOP_N_TOKENS=] [default: 5] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 6c4e85d6..ce2af08f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -238,7 +238,7 @@ struct Args { max_stop_sequences: usize, /// This is the maximum allowed value for clients to set `top_n_tokens`. - /// `top_n_tokens is used to return information about the the `n` most likely + /// `top_n_tokens` is used to return information about the the `n` most likely /// tokens at each generation step, instead of just the sampled token. This /// information can be used for downstream tasks like for classification or /// ranking. From 7752f1050b544748a804f5c163df84948a90bb57 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sat, 1 Jun 2024 08:47:00 +0000 Subject: [PATCH 017/340] Fixing Phi3. --- .../models/custom_modeling/flash_llama_modeling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c0fa09fd..cef712f0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -52,7 +52,8 @@ if SYSTEM == "rocm": def load_attention(config, prefix, weights): - bias = config.attention_bias + # Only defined in granite. + bias = getattr(config, "attention_bias", False) # if specific model type, load the correct attention if config.model_type == "phi3": From b30b2a6dae3760b6030f0fd24fc1f38ae5f939bf Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 3 Jun 2024 10:36:29 +0200 Subject: [PATCH 018/340] Fixing GPTQ imports. (#1994) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/text_generation_server/layers/linear.py | 6 ++++-- .../models/custom_modeling/flash_dbrx_modeling.py | 3 ++- .../models/custom_modeling/flash_santacoder_modeling.py | 3 ++- server/text_generation_server/utils/weights.py | 6 ++++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 570aa75c..1d131e0b 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -2,8 +2,6 @@ from typing import Optional import torch from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.gptq import GPTQWeight if SYSTEM == "rocm": try: @@ -155,6 +153,8 @@ def get_linear(weight, bias, quantize): quant_type="nf4", ) elif quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2Weight + if not isinstance(weight, Exl2Weight): raise NotImplementedError( f"The passed weight is not `exl2` compatible, loader needs to be updated." @@ -165,6 +165,8 @@ def get_linear(weight, bias, quantize): linear = ExllamaQuantLinear(weight, bias) elif quantize == "gptq": + from text_generation_server.layers.gptq import GPTQWeight + if not isinstance(weight, GPTQWeight): raise NotImplementedError( f"The passed weight is not `gptq` compatible, loader needs to be updated." diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 497956e3..7967e420 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -21,7 +21,6 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger -from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": @@ -198,6 +197,8 @@ def _load_gqa(config, prefix: str, weights): v_stop = v_offset + (rank + 1) * kv_block_size if config.quantize in ["gptq", "awq"]: + from text_generation_server.layers.gptq import GPTQWeight + try: qweight_slice = weights._get_slice(f"{prefix}.qweight") q_qweight = qweight_slice[:, q_start:q_stop] diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c8397000..1f47550e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,7 +5,6 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.attention import ( paged_attention, attention, @@ -39,6 +38,8 @@ def load_multi_mqa( def _load_multi_mqa_gptq( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): + from text_generation_server.layers.gptq import GPTQWeight + if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose: world_size = weights.process_group.size() rank = weights.process_group.rank() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 710ea680..5782de8a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -7,8 +7,6 @@ import torch from loguru import logger from huggingface_hub import hf_hub_download import json -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.log import log_once @@ -221,6 +219,8 @@ class Weights: def get_weights_col(self, prefix: str, quantize: str): if quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2Weight + try: q_weight = self.get_tensor(f"{prefix}.q_weight") except RuntimeError: @@ -247,6 +247,8 @@ class Weights: if quantize == "exl2": raise ValueError("get_multi_weights_col is not supported for exl2") elif quantize in ["gptq", "awq"]: + from text_generation_server.layers.gptq import GPTQWeight + try: qweight = torch.cat( [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 From b3b175568f6a103f2eab73f2c39f1959a21b2657 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 3 Jun 2024 09:32:12 +0000 Subject: [PATCH 019/340] Hotfix GPTQ. --- server/text_generation_server/layers/linear.py | 2 ++ server/text_generation_server/utils/weights.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 1d131e0b..ff99388e 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -196,6 +196,8 @@ def get_linear(weight, bias, quantize): weight.groupsize, ) elif quantize == "awq": + from text_generation_server.layers.gptq import GPTQWeight + if not isinstance(weight, GPTQWeight): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 5782de8a..2dfd80bf 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -154,6 +154,8 @@ class Weights: already alternating Q,K,V within the main tensor """ if quantize in ["gptq", "awq"]: + from text_generation_server.layers.gptq import GPTQWeight + try: qweight = self._get_qweight(f"{prefix}.qweight") except RuntimeError: @@ -331,6 +333,8 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2Weight + try: q_weight = self.get_tensor(f"{prefix}.q_weight") except RuntimeError: @@ -390,7 +394,11 @@ class Weights: # it would require to reorder input activations that are split unto several GPUs use_exllama = False - from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) if use_exllama: if not HAS_EXLLAMA: @@ -442,6 +450,8 @@ class Weights: use_exllama=use_exllama, ) elif quantize == "awq": + from text_generation_server.layers.gptq import GPTQWeight + bits, groupsize, _, _ = self._get_gptq_params() try: From 347ecdae3bca2c73874cf12d8690db2ffadef5d2 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 3 Jun 2024 22:07:50 +0800 Subject: [PATCH 020/340] reable xpu, broken by gptq and setuptool upgrade (#1988) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --------- Signed-off-by: Wang, Yi A --- Dockerfile_intel | 2 +- server/requirements_intel.txt | 48 +++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 server/requirements_intel.txt diff --git a/Dockerfile_intel b/Dockerfile_intel index 809992e1..9c9b5c16 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -66,7 +66,7 @@ COPY server server COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ - pip install -r requirements_cuda.txt && \ + pip install -r requirements_intel.txt && \ pip install ".[accelerate, peft, outlines]" --no-cache-dir ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt new file mode 100644 index 00000000..5751bf81 --- /dev/null +++ b/server/requirements_intel.txt @@ -0,0 +1,48 @@ +backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +idna==3.7 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" +py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" +setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" From d51f2c465fb6f24167e05915784b6287685d8921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 3 Jun 2024 07:27:22 +0000 Subject: [PATCH 021/340] router: send the input as chunks to the backend Before this change, the generation input was sent to the backend as a single string, encoding images as Base64 and packing them in Markdown-style links. This change adds a new chunked input representation that separates text chunks from images chunks. Image chunks contain binary data (for smaller message sizes) and the image's MIME type. The stringly-typed inputs are still sent to support backends that do not support chunked inputs yet. --- Cargo.toml | 1 + benchmark/src/generation.rs | 7 +- proto/generate.proto | 25 +++++- router/Cargo.toml | 2 +- router/client/Cargo.toml | 1 + router/client/src/client.rs | 31 ++++++- router/client/src/lib.rs | 35 +++++++- router/src/config.rs | 14 ++-- router/src/health.rs | 7 +- router/src/queue.rs | 9 +- router/src/validation.rs | 158 ++++++++++++++++++++++++------------ 11 files changed, 221 insertions(+), 69 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aafc8435..8b03a3ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" [workspace.dependencies] +base64 = "0.22.0" tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index ea7c9778..8c07e62b 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -1,7 +1,7 @@ use std::time::{Duration, Instant}; use text_generation_client::{ - Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient, - StoppingCriteriaParameters, + Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request, + ShardedClient, StoppingCriteriaParameters, }; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::sync::{broadcast, mpsc}; @@ -142,6 +142,9 @@ async fn prefill( .map(|id| Request { id: id.into(), prefill_logprobs: false, + input_chunks: Some(Input { + chunks: vec![Chunk::Text(sequence.clone()).into()], + }), inputs: sequence.clone(), truncate: sequence_length, parameters: Some(parameters.clone()), diff --git a/proto/generate.proto b/proto/generate.proto index 9921faea..e6993f03 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -51,6 +51,27 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} +message Image { + /// Binary image data. + bytes data = 1; + + /// Image MIME type. + string mimetype = 2; +} + +message InputChunk { + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } +} + +message Input { + repeated InputChunk chunks = 1; + } + enum GrammarType { GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_JSON = 1; @@ -95,7 +116,9 @@ message StoppingCriteriaParameters { message Request { /// Request ID uint64 id = 1; - /// The generation context + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks string inputs = 2; /// Context truncation uint32 truncate = 3; diff --git a/router/Cargo.toml b/router/Cargo.toml index fdfe1a5b..2e6264be 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -49,7 +49,7 @@ futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" image = "0.25.1" -base64 = "0.22.0" +base64 = { workspace = true } [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index d0131784..abbde82d 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -6,6 +6,7 @@ authors.workspace = true homepage.workspace = true [dependencies] +base64 = { workspace = true } futures = "^0.3" grpc-metadata = { path = "../grpc-metadata" } prost = "^0.12" diff --git a/router/client/src/client.rs b/router/client/src/client.rs index e8035106..8b509d6b 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,13 +1,17 @@ /// Single shard Client use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v2::*; -use crate::Result; +use crate::{Chunk, Result}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; use grpc_metadata::InjectTelemetryContext; use std::cmp::min; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + /// Text Generation Inference gRPC client #[derive(Debug, Clone)] pub struct Client { @@ -113,18 +117,39 @@ impl Client { while n_tokens < max_prefill_tokens { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. let mut inputs = String::new(); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. - inputs.push_str("![]()"); + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); } requests.push(Request { id: 0, - // We truncate the input on the server side to be sure that it has the correct size + input_chunks: Some(Input { + chunks: input_chunks, + }), inputs, + // We truncate the input on the server side to be sure that it has the correct size truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 6782d9ff..9e9ef13b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -5,11 +5,14 @@ mod client; mod pb; mod sharded_client; +use base64::{engine::general_purpose::STANDARD, Engine}; pub use client::Client; +pub use pb::generate::v2::input_chunk::Chunk; pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::Image; pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; @@ -44,3 +47,33 @@ impl From for ClientError { } pub type Result = std::result::Result; + +// Small convenience re-wrapping of `Chunk`. +impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. + fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c.chunk { + Some(Chunk::Text(text)) => output.push_str(text), + Some(Chunk::Image(Image { data, mimetype })) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + // We don't create empty chunks, so this should be unreachable. + None => unreachable!("Chunks should never be empty"), + }); + output + } +} diff --git a/router/src/config.rs b/router/src/config.rs index d27b1136..29fefd5b 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub struct LlavaNext { - text_config: TextConfig, - vision_config: VisionConfig, - image_grid_pinpoints: Vec<(usize, usize)>, + pub(crate) text_config: TextConfig, + pub(crate) vision_config: VisionConfig, + pub(crate) image_grid_pinpoints: Vec<(usize, usize)>, } fn get_anyres_image_grid_shape( @@ -119,13 +119,13 @@ impl Idefics2 { #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct PaliTextConfig { - num_image_tokens: usize, + pub(crate) num_image_tokens: usize, } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct Paligemma { - text_config: PaliTextConfig, + pub(crate) text_config: PaliTextConfig, } impl Paligemma { @@ -175,8 +175,8 @@ pub struct TextConfig {} #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct VisionConfig { - image_size: usize, - patch_size: usize, + pub(crate) image_size: usize, + pub(crate) patch_size: usize, } #[cfg(test)] diff --git a/router/src/health.rs b/router/src/health.rs index b05b3094..121255b9 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -1,9 +1,9 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use text_generation_client::GrammarType as ProtoGrammarType; use text_generation_client::{ - Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, + Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; +use text_generation_client::{Chunk, GrammarType as ProtoGrammarType}; // Note: Request ids and batch ids cannot collide. const LIVENESS_ID: u64 = u64::MAX; @@ -33,6 +33,9 @@ impl Health { // Dummy batch of 1 token and 1 generated token let liveness_request = Request { id: LIVENESS_ID, + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), inputs: "liveness".to_string(), truncate: 10, prefill_logprobs: false, diff --git a/router/src/queue.rs b/router/src/queue.rs index a32673dd..40692ffc 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -4,6 +4,8 @@ use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use std::collections::VecDeque; +use text_generation_client::ChunksToString; +use text_generation_client::Input; use text_generation_client::{Batch, Request}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; @@ -278,7 +280,10 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.clone(), + input_chunks: Some(Input { + chunks: entry.request.inputs.clone(), + }), + inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), @@ -366,7 +371,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { - inputs: String::new(), + inputs: vec![], input_length: 0, truncate: 0, decoder_input_details: false, diff --git a/router/src/validation.rs b/router/src/validation.rs index 96b6cb27..863bb99b 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -7,7 +7,8 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use text_generation_client::{ - GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, + Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters, + StoppingCriteriaParameters, }; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; @@ -89,7 +90,7 @@ impl Validation { &self, inputs: String, truncate: Option, - ) -> Result, ValidationError> { + ) -> Result)>, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -115,7 +116,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(String, usize, u32), ValidationError> { + ) -> Result<(Vec, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -178,7 +179,11 @@ impl Validation { // )); } - Ok((inputs, input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs).into()], + input_length, + max_new_tokens, + )) } } @@ -465,7 +470,7 @@ fn format_to_mimetype(format: ImageFormat) -> String { .to_string() } -fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { +fn fetch_image(input: &str) -> Result<(Vec, String, usize, usize), ValidationError> { if input.starts_with("![](http://") || input.starts_with("![](https://") { let url = &input["![](".len()..input.len() - 1]; let data = reqwest::blocking::get(url)?.bytes()?; @@ -476,9 +481,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { let height: usize = img.height().try_into()?; let width: usize = img.width().try_into()?; let mimetype = format_to_mimetype(format); - let encoded = STANDARD.encode(data); - let data_uri = format!("![](data:{mimetype};base64,{encoded})"); - Ok((data_uri, height, width)) + Ok((data.to_vec(), mimetype, height, width)) } else if input.starts_with("![](data:") { // Remove ![](....) let content = &input["![](data:".len()..input.len() - 1]; @@ -495,9 +498,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { let data = STANDARD.decode(content["base64,".len()..].as_bytes())?; let img = if let Some(format) = format_from_mimetype(mimetype) { - ImageReader::with_format(Cursor::new(data), format).decode()? + ImageReader::with_format(Cursor::new(&data), format).decode()? } else { - ImageReader::new(Cursor::new(data)) + ImageReader::new(Cursor::new(&data)) .with_guessed_format() .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))? .decode()? @@ -505,7 +508,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { let height: usize = img.height().try_into()?; let width: usize = img.width().try_into()?; - Ok((input.to_string(), height, width)) + Ok((data, mimetype.to_string(), height, width)) } else { Err(ValidationError::InvalidImageContent(input.to_string())) } @@ -513,113 +516,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { /// Get input length and optionally truncate it fn prepare_input( - mut inputs: String, + inputs: String, _truncate: Option, tokenizer: &Tokenizer, config: &Option, -) -> Result<(tokenizers::Encoding, String), ValidationError> { +) -> Result<(tokenizers::Encoding, Vec), ValidationError> { static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); - let tokenizer_query = match config { + let (tokenizer_query, input_chunks) = match config { Some(Config::LlavaNext(config)) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let slots = config.get_number_of_features(height, width); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); tokenizer_query.push_str(&"".repeat(slots)); - modified_inputs.push_str(&image_uri); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } Some(Config::Paligemma(config)) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let slots = config.get_number_of_features(height, width); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); tokenizer_query.push_str(&"".repeat(slots)); - modified_inputs.push_str(&image_uri); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } Some(Config::Idefics2(config)) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let slots = config.get_number_of_features(height, width); tokenizer_query.push_str(""); tokenizer_query.push_str(&"".repeat(slots)); tokenizer_query.push_str(""); - modified_inputs.push_str(&image_uri); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } Some(Config::Idefics) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, _height, _width) = + fetch_image(&inputs[chunk_start..chunk_end])?; let slots = 1; tokenizer_query.push_str(&"".repeat(slots)); - modified_inputs.push_str(&image_uri); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } - _ => inputs.clone(), + _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), }; // Get the number of tokens in the input @@ -627,18 +627,18 @@ fn prepare_input( .encode(tokenizer_query, true) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; - Ok((encoding, inputs)) + Ok((encoding, input_chunks)) } type TokenizerRequest = ( (String, Option), - oneshot::Sender>, + oneshot::Sender), ValidationError>>, Span, ); #[derive(Debug, Clone)] pub(crate) struct ValidGenerateRequest { - pub inputs: String, + pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, @@ -714,6 +714,7 @@ pub enum ValidationError { #[cfg(test)] mod tests { use super::*; + use crate::config::{PaliTextConfig, Paligemma}; use crate::default_parameters; use crate::tests::get_tokenizer; @@ -964,4 +965,61 @@ mod tests { assert_eq!(valid_request.top_n_tokens, 0); } + + static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw=="; + + #[tokio::test] + async fn test_prepare_input_chunks() { + let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap(); + + let tokenizer = Some(get_tokenizer().await); + + let max_best_of = 2; + let max_stop_sequence = 3; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; + let disable_grammar_support = true; + let workers = 1; + let config = Config::Paligemma(Paligemma { + text_config: PaliTextConfig { + num_image_tokens: 1, + }, + }); + let validation = Validation::new( + workers, + tokenizer, + Some(config), + max_best_of, + max_stop_sequence, + max_top_n_tokens, + max_input_length, + max_total_tokens, + disable_grammar_support, + ); + + let chunks = match validation + .tokenize( + format!("test![](data:image/gif;base64,{})", PIXEL_GIF), + None, + ) + .await + { + Ok(Some((_encoding, chunks))) => chunks, + _ => panic!("Unexpected tokenization failure"), + }; + + assert!( + chunks + == vec![ + Chunk::Text("test".to_string()).into(), + Chunk::Image(Image { + data: pixel_data.clone(), + mimetype: "image/gif".to_string() + }) + .into() + ], + "Failed to process images", + ); + } } From 75aed8aed5d2e83e80e1cf2c97ee24f852486467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 4 Jun 2024 14:26:07 +0200 Subject: [PATCH 022/340] Fix Phi-2 with `tp>1` (#2003) # What does this PR do? We were using the wrong parallelism in the up-projection. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../models/custom_modeling/flash_phi_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index af3206dd..53d3ea42 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -238,7 +238,7 @@ class PhiMLP(nn.Module): ) # llama weights are up_proj and down_proj and bias=False - self.up_proj = TensorParallelRowLinear.load( + self.up_proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.fc1", weights=weights, From 63de9ff020bb05adc58a88ab6e5981a91b3c8269 Mon Sep 17 00:00:00 2001 From: Emmanuel Ferdman Date: Tue, 4 Jun 2024 15:26:35 +0300 Subject: [PATCH 023/340] fix: update triton implementation reference (#2002) # What does this PR do? PR #1986 moved the location of the `flash_attn_triton.py` file. This PR adjusts sources to changes. ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. Signed-off-by: Emmanuel Ferdman --- docs/source/installation_amd.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index d70953ae..bf7f9c75 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -27,7 +27,7 @@ TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you wo ## Flash attention implementation -Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py). +Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py). By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container. From 184c89fd55c29b7877bead7aa92edd8844d3cea3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 4 Jun 2024 15:56:56 +0200 Subject: [PATCH 024/340] feat: add SchedulerV3 (#1996) - Refactor code to allow supporting multiple versions of the generate.proto at the same time - Add v3/generate.proto (ISO to generate.proto for now but allow for future changes without impacting v2 backends) - Add Schedule trait to abstract queuing and batching mechanisms that will be different in the future - Add SchedulerV2/V3 impl --- benchmark/src/generation.rs | 7 +- benchmark/src/lib.rs | 2 +- benchmark/src/main.rs | 2 +- proto/generate.proto | 25 +- proto/v3/generate.proto | 259 ++++ router/client/Cargo.toml | 1 + router/client/build.rs | 18 +- router/client/src/lib.rs | 44 +- router/client/src/pb/.gitignore | 1 - router/client/src/pb/generate.v2.rs | 647 +++++++++ router/client/src/pb/mod.rs | 6 + router/client/src/v2/client.rs | 258 ++++ router/client/src/v2/mod.rs | 13 + router/client/src/{ => v2}/sharded_client.rs | 70 +- router/client/src/{ => v3}/client.rs | 16 +- router/client/src/v3/mod.rs | 13 + router/client/src/v3/sharded_client.rs | 254 ++++ router/src/health.rs | 75 -- router/src/infer/health.rs | 34 + router/src/infer/mod.rs | 522 ++++++++ router/src/infer/v2/mod.rs | 4 + router/src/infer/v2/queue.rs | 667 ++++++++++ .../src/{infer.rs => infer/v2/scheduler.rs} | 550 +------- router/src/infer/v3/mod.rs | 4 + router/src/{ => infer/v3}/queue.rs | 72 +- router/src/infer/v3/scheduler.rs | 1177 +++++++++++++++++ router/src/lib.rs | 19 +- router/src/main.rs | 71 +- router/src/server.rs | 401 ++++-- router/src/validation.rs | 83 +- server/Makefile | 4 +- 31 files changed, 4428 insertions(+), 891 deletions(-) create mode 100644 proto/v3/generate.proto delete mode 100644 router/client/src/pb/.gitignore create mode 100644 router/client/src/pb/generate.v2.rs create mode 100644 router/client/src/pb/mod.rs create mode 100644 router/client/src/v2/client.rs create mode 100644 router/client/src/v2/mod.rs rename router/client/src/{ => v2}/sharded_client.rs (75%) rename router/client/src/{ => v3}/client.rs (90%) create mode 100644 router/client/src/v3/mod.rs create mode 100644 router/client/src/v3/sharded_client.rs delete mode 100644 router/src/health.rs create mode 100644 router/src/infer/health.rs create mode 100644 router/src/infer/mod.rs create mode 100644 router/src/infer/v2/mod.rs create mode 100644 router/src/infer/v2/queue.rs rename router/src/{infer.rs => infer/v2/scheduler.rs} (76%) create mode 100644 router/src/infer/v3/mod.rs rename router/src/{ => infer/v3}/queue.rs (90%) create mode 100644 router/src/infer/v3/scheduler.rs diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 8c07e62b..27b74249 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -1,8 +1,9 @@ use std::time::{Duration, Instant}; -use text_generation_client::{ - Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request, - ShardedClient, StoppingCriteriaParameters, +use text_generation_client::v3::{ + Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient, + StoppingCriteriaParameters, }; +use text_generation_client::{Chunk, ClientError, Input}; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::sync::{broadcast, mpsc}; diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 638c6514..c33d64e6 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -8,7 +8,7 @@ use crate::app::App; use crate::event::Event; use crossterm::ExecutableCommand; use std::io; -use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; +use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 2d89e045..b9d80b7a 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -4,7 +4,7 @@ /// and: https://github.com/orhun/rust-tui-template use clap::Parser; use std::path::Path; -use text_generation_client::ShardedClient; +use text_generation_client::v3::ShardedClient; use tokenizers::{FromPretrainedParameters, Tokenizer}; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; diff --git a/proto/generate.proto b/proto/generate.proto index e6993f03..9921faea 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -51,27 +51,6 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} -message Image { - /// Binary image data. - bytes data = 1; - - /// Image MIME type. - string mimetype = 2; -} - -message InputChunk { - oneof chunk { - /// Plain text data - string text = 1; - /// Image data - Image image = 2; - } -} - -message Input { - repeated InputChunk chunks = 1; - } - enum GrammarType { GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_JSON = 1; @@ -116,9 +95,7 @@ message StoppingCriteriaParameters { message Request { /// Request ID uint64 id = 1; - /// The generation context as chunks - Input input_chunks = 8; - /// The generation context, stringified input_chunks + /// The generation context string inputs = 2; /// Context truncation uint32 truncate = 3; diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto new file mode 100644 index 00000000..ca2908c9 --- /dev/null +++ b/proto/v3/generate.proto @@ -0,0 +1,259 @@ +syntax = "proto3"; + +package generate.v3; + +service TextGenerationService { + /// Model Info + rpc Info (InfoRequest) returns (InfoResponse) {} + /// Service discovery + rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} + /// Empties batch cache + rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); + /// Remove requests from a cached batch + rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup (WarmupRequest) returns (WarmupResponse); + /// Prefill batch and decode first token + rpc Prefill (PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc Decode (DecodeRequest) returns (DecodeResponse); + /// Health check + rpc Health (HealthRequest) returns (HealthResponse); +} + +message HealthRequest {} +message HealthResponse {} + +/// Empty request +message InfoRequest {} + +message InfoResponse { + bool requires_padding = 1; + string dtype = 2; + string device_type = 3; + optional uint32 window_size = 4; + uint32 speculate = 5; +} + +/// Empty request +message ServiceDiscoveryRequest {} + +message ServiceDiscoveryResponse { + /// Other shards urls + repeated string urls = 1; +} + +message ClearCacheRequest { + /// Optional batch id + optional uint64 id = 1; +} + +/// Empty response +message ClearCacheResponse {} + +message Image { + /// Binary image data. + bytes data = 1; + + /// Image MIME type. + string mimetype = 2; +} + +message InputChunk { + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } +} + +message Input { + repeated InputChunk chunks = 1; + } + +enum GrammarType { + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; +} + +message NextTokenChooserParameters { + /// exponential scaling output probability distribution + float temperature = 1; + /// restricting to the k highest probability elements + uint32 top_k = 2; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float top_p = 3; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float typical_p = 4; + /// apply sampling on the logits + bool do_sample = 5; + /// random seed for sampling + uint64 seed = 6; + /// repetition penalty + float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 9; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; +} + +message StoppingCriteriaParameters { + /// Maximum number of generated tokens + uint32 max_new_tokens = 1; + /// Optional stopping sequences + repeated string stop_sequences = 2; + /// Ignore end of sequence token + /// used for benchmarking + bool ignore_eos_token = 3; +} + +message Request { + /// Request ID + uint64 id = 1; + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks + string inputs = 2; + /// Context truncation + uint32 truncate = 3; + /// Next Token Chooser Parameters + NextTokenChooserParameters parameters = 4; + /// Stopping Criteria Parameters + StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_n_tokens = 7; +} + +message Batch { + /// Batch ID + uint64 id = 1; + /// Individual requests + repeated Request requests = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; +} + +message CachedBatch { + /// Batch ID + uint64 id = 1; + /// Individual requests ids + repeated uint64 request_ids = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; +} + +enum FinishReason { + FINISH_REASON_LENGTH = 0; + FINISH_REASON_EOS_TOKEN = 1; + FINISH_REASON_STOP_SEQUENCE = 2; +} + +message GeneratedText { + /// Output + string text = 1; + /// Number of generated tokens + uint32 generated_tokens = 2; + /// Finish reason + FinishReason finish_reason = 3; + /// Seed + optional uint64 seed = 4; +} + +message Tokens { + /// Token IDs + repeated uint32 ids = 1; + /// Logprobs + repeated float logprobs = 2; + /// tokens + repeated string texts = 3; + /// special + repeated bool is_special = 4; +} + +message Generation { + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + Tokens prefill_tokens = 2; + Tokens tokens = 3; + /// Complete generated text + optional GeneratedText generated_text = 4; + /// Top tokens + repeated Tokens top_tokens = 5; +} + +message FilterBatchRequest { + /// Batch ID + uint64 batch_id = 1; + /// Requests to keep + repeated uint64 request_ids = 2; +} + +message FilterBatchResponse { + /// Filtered Batch (cached) + CachedBatch batch = 1; +} + + +message PrefillRequest { + /// Batch + Batch batch = 1; +} + +message PrefillResponse { + /// Generation + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; +} + +message DecodeRequest { + /// Cached batches + repeated CachedBatch batches = 1; +} + +message DecodeResponse { + /// Decodes + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; +} + +message WarmupRequest { + /// Batch to warmup on + Batch batch = 1; + uint32 max_input_length = 2; + uint32 max_prefill_tokens = 3; + uint32 max_total_tokens = 4; +} + +message WarmupResponse { + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; +} diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index abbde82d..db423c4b 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -6,6 +6,7 @@ authors.workspace = true homepage.workspace = true [dependencies] +async-trait = "^0.1" base64 = { workspace = true } futures = "^0.3" grpc-metadata = { path = "../grpc-metadata" } diff --git a/router/client/build.rs b/router/client/build.rs index 497be545..bcfab74f 100644 --- a/router/client/build.rs +++ b/router/client/build.rs @@ -1,19 +1,31 @@ use std::fs; fn main() -> Result<(), Box> { - println!("cargo:rerun-if-changed=../../proto/generate.proto"); - fs::create_dir("src/pb").unwrap_or(()); + println!("cargo:rerun-if-changed=../../proto/**"); + fs::create_dir_all("src/v2/pb").unwrap_or(()); let mut config = prost_build::Config::new(); config.protoc_arg("--experimental_allow_proto3_optional"); tonic_build::configure() .build_client(true) .build_server(false) - .out_dir("src/pb") + .out_dir("src/v2/pb") .include_file("mod.rs") .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + fs::create_dir_all("src/v3/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/v3/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + Ok(()) } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 9e9ef13b..45bee10c 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -1,25 +1,35 @@ //! Text Generation gRPC client library -mod client; -#[allow(clippy::derive_partial_eq_without_eq)] -mod pb; -mod sharded_client; +pub mod v2; +pub mod v3; +use async_trait::async_trait; use base64::{engine::general_purpose::STANDARD, Engine}; -pub use client::Client; -pub use pb::generate::v2::input_chunk::Chunk; -pub use pb::generate::v2::HealthResponse; -pub use pb::generate::v2::Image; -pub use pb::generate::v2::InfoResponse as ShardInfo; -pub use pb::generate::v2::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk, - NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, -}; -pub use sharded_client::ShardedClient; use thiserror::Error; use tonic::transport; use tonic::Status; +pub use v3::{Chunk, Image, Input, InputChunk}; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. + /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + #[derive(Error, Debug, Clone)] pub enum ClientError { #[error("Could not connect to Text Generation server: {0}")] @@ -46,8 +56,6 @@ impl From for ClientError { } } -pub type Result = std::result::Result; - // Small convenience re-wrapping of `Chunk`. impl From for InputChunk { fn from(chunk: Chunk) -> Self { @@ -77,3 +85,7 @@ impl ChunksToString for Vec { output } } + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/router/client/src/pb/.gitignore b/router/client/src/pb/.gitignore deleted file mode 100644 index 6f5f3d11..00000000 --- a/router/client/src/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.rs diff --git a/router/client/src/pb/generate.v2.rs b/router/client/src/pb/generate.v2.rs new file mode 100644 index 00000000..1a206360 --- /dev/null +++ b/router/client/src/pb/generate.v2.rs @@ -0,0 +1,647 @@ +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthResponse {} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoResponse { + #[prost(bool, tag = "1")] + pub requires_padding: bool, + #[prost(string, tag = "2")] + pub dtype: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub device_type: ::prost::alloc::string::String, + #[prost(uint32, optional, tag = "4")] + pub window_size: ::core::option::Option, + #[prost(uint32, tag = "5")] + pub speculate: u32, +} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryResponse { + /// / Other shards urls + #[prost(string, repeated, tag = "1")] + pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheRequest { + /// / Optional batch id + #[prost(uint64, optional, tag = "1")] + pub id: ::core::option::Option, +} +/// / Empty response +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheResponse {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NextTokenChooserParameters { + /// / exponential scaling output probability distribution + #[prost(float, tag = "1")] + pub temperature: f32, + /// / restricting to the k highest probability elements + #[prost(uint32, tag = "2")] + pub top_k: u32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "3")] + pub top_p: f32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "4")] + pub typical_p: f32, + /// / apply sampling on the logits + #[prost(bool, tag = "5")] + pub do_sample: bool, + /// / random seed for sampling + #[prost(uint64, tag = "6")] + pub seed: u64, + /// / repetition penalty + #[prost(float, tag = "7")] + pub repetition_penalty: f32, + /// / frequency penalty + #[prost(float, tag = "9")] + pub frequency_penalty: f32, + /// / token watermarking using "A Watermark for Large Language Models" + #[prost(bool, tag = "8")] + pub watermark: bool, + /// / grammar (applied if not empty) + #[prost(string, tag = "10")] + pub grammar: ::prost::alloc::string::String, + /// / grammar type + #[prost(enumeration = "GrammarType", tag = "11")] + pub grammar_type: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StoppingCriteriaParameters { + /// / Maximum number of generated tokens + #[prost(uint32, tag = "1")] + pub max_new_tokens: u32, + /// / Optional stopping sequences + #[prost(string, repeated, tag = "2")] + pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / Ignore end of sequence token + /// / used for benchmarking + #[prost(bool, tag = "3")] + pub ignore_eos_token: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Request { + /// / Request ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / The generation context + #[prost(string, tag = "2")] + pub inputs: ::prost::alloc::string::String, + /// / Context truncation + #[prost(uint32, tag = "3")] + pub truncate: u32, + /// / Next Token Chooser Parameters + #[prost(message, optional, tag = "4")] + pub parameters: ::core::option::Option, + /// / Stopping Criteria Parameters + #[prost(message, optional, tag = "5")] + pub stopping_parameters: ::core::option::Option, + /// / Return prefill logprobs + #[prost(bool, tag = "6")] + pub prefill_logprobs: bool, + /// / Return most likely n tokens + #[prost(uint32, tag = "7")] + pub top_n_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Batch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests + #[prost(message, repeated, tag = "2")] + pub requests: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CachedBatch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests ids + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GeneratedText { + /// / Output + #[prost(string, tag = "1")] + pub text: ::prost::alloc::string::String, + /// / Number of generated tokens + #[prost(uint32, tag = "2")] + pub generated_tokens: u32, + /// / Finish reason + #[prost(enumeration = "FinishReason", tag = "3")] + pub finish_reason: i32, + /// / Seed + #[prost(uint64, optional, tag = "4")] + pub seed: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Tokens { + /// / Token IDs + #[prost(uint32, repeated, tag = "1")] + pub ids: ::prost::alloc::vec::Vec, + /// / Logprobs + #[prost(float, repeated, tag = "2")] + pub logprobs: ::prost::alloc::vec::Vec, + /// / tokens + #[prost(string, repeated, tag = "3")] + pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / special + #[prost(bool, repeated, tag = "4")] + pub is_special: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Generation { + /// / Request ID + #[prost(uint64, tag = "1")] + pub request_id: u64, + /// / Prefill tokens (optional) + #[prost(message, optional, tag = "2")] + pub prefill_tokens: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub tokens: ::core::option::Option, + /// / Complete generated text + #[prost(message, optional, tag = "4")] + pub generated_text: ::core::option::Option, + /// / Top tokens + #[prost(message, repeated, tag = "5")] + pub top_tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchRequest { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub batch_id: u64, + /// / Requests to keep + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchResponse { + /// / Filtered Batch (cached) + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillRequest { + /// / Batch + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillResponse { + /// / Generation + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeRequest { + /// / Cached batches + #[prost(message, repeated, tag = "1")] + pub batches: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeResponse { + /// / Decodes + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, + /// / Concatenate elapsed time in nanoseconds + #[prost(uint64, optional, tag = "6")] + pub concat_ns: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupRequest { + /// / Batch to warmup on + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, + #[prost(uint32, tag = "2")] + pub max_input_length: u32, + #[prost(uint32, tag = "3")] + pub max_prefill_tokens: u32, + #[prost(uint32, tag = "4")] + pub max_total_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupResponse { + /// / Maximum number of tokens supported by the model + #[prost(uint32, optional, tag = "1")] + pub max_supported_total_tokens: ::core::option::Option, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum GrammarType { + None = 0, + Json = 1, + Regex = 2, +} +impl GrammarType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + GrammarType::None => "GRAMMAR_TYPE_NONE", + GrammarType::Json => "GRAMMAR_TYPE_JSON", + GrammarType::Regex => "GRAMMAR_TYPE_REGEX", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GRAMMAR_TYPE_NONE" => Some(Self::None), + "GRAMMAR_TYPE_JSON" => Some(Self::Json), + "GRAMMAR_TYPE_REGEX" => Some(Self::Regex), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum FinishReason { + Length = 0, + EosToken = 1, + StopSequence = 2, +} +impl FinishReason { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + FinishReason::Length => "FINISH_REASON_LENGTH", + FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN", + FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FINISH_REASON_LENGTH" => Some(Self::Length), + "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken), + "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence), + _ => None, + } + } +} +/// Generated client implementations. +pub mod text_generation_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct TextGenerationServiceClient { + inner: tonic::client::Grpc, + } + impl TextGenerationServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl TextGenerationServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> TextGenerationServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// / Model Info + pub async fn info( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Info", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info")); + self.inner.unary(req, path, codec).await + } + /// / Service discovery + pub async fn service_discovery( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/ServiceDiscovery", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "generate.v2.TextGenerationService", + "ServiceDiscovery", + ), + ); + self.inner.unary(req, path, codec).await + } + /// / Empties batch cache + pub async fn clear_cache( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/ClearCache", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"), + ); + self.inner.unary(req, path, codec).await + } + /// / Remove requests from a cached batch + pub async fn filter_batch( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/FilterBatch", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"), + ); + self.inner.unary(req, path, codec).await + } + /// / Warmup the model and compute max cache size + pub async fn warmup( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Warmup", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup")); + self.inner.unary(req, path, codec).await + } + /// / Prefill batch and decode first token + pub async fn prefill( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Prefill", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill")); + self.inner.unary(req, path, codec).await + } + /// / Decode token for a list of prefilled batches + pub async fn decode( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Decode", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode")); + self.inner.unary(req, path, codec).await + } + /// / Health check + pub async fn health( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Health", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health")); + self.inner.unary(req, path, codec).await + } + } +} diff --git a/router/client/src/pb/mod.rs b/router/client/src/pb/mod.rs new file mode 100644 index 00000000..095ead1f --- /dev/null +++ b/router/client/src/pb/mod.rs @@ -0,0 +1,6 @@ +// This file is @generated by prost-build. +pub mod generate { + pub mod v2 { + include!("generate.v2.rs"); + } +} diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs new file mode 100644 index 00000000..9a2e6ac7 --- /dev/null +++ b/router/client/src/v2/client.rs @@ -0,0 +1,258 @@ +/// Single shard Client +use crate::v2::pb; +use crate::{ClientError, Result}; + +use crate::WARMUP_IMAGE_BASE64; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v2::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v2 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/router/client/src/v2/mod.rs b/router/client/src/v2/mod.rs new file mode 100644 index 00000000..6b14b9f3 --- /dev/null +++ b/router/client/src/v2/mod.rs @@ -0,0 +1,13 @@ +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod client; +mod sharded_client; + +pub use client::Client; +pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::{ + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, +}; +pub use sharded_client::ShardedClient; diff --git a/router/client/src/sharded_client.rs b/router/client/src/v2/sharded_client.rs similarity index 75% rename from router/client/src/sharded_client.rs rename to router/client/src/v2/sharded_client.rs index e1e52d59..7b24aec3 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/v2/sharded_client.rs @@ -1,10 +1,17 @@ -use crate::client::{DecodeTimings, PrefillTimings}; /// Multi shard Client -use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; +use crate::{v2, Health, ShardInfo}; use crate::{ClientError, Result}; + +use crate::v2::InfoResponse; +use async_trait::async_trait; use futures::future::join_all; use tonic::transport::Uri; use tracing::instrument; +use v2::client::{DecodeTimings, PrefillTimings}; +use v2::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; #[derive(Debug, Clone)] /// Text Generation Inference gRPC multi client @@ -47,7 +54,7 @@ impl ShardedClient { .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap() + join_all(futures).await.pop().unwrap().map(ShardInfo::from) } /// GRPC health check @@ -185,3 +192,60 @@ impl ShardedClient { Ok((generations, next_batch, timings)) } } + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/router/client/src/client.rs b/router/client/src/v3/client.rs similarity index 90% rename from router/client/src/client.rs rename to router/client/src/v3/client.rs index 8b509d6b..1f3a89a0 100644 --- a/router/client/src/client.rs +++ b/router/client/src/v3/client.rs @@ -1,17 +1,16 @@ +use crate::v3::{pb, Chunk}; +use crate::{ClientError, Result, WARMUP_IMAGE_BASE64}; /// Single shard Client -use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; -use crate::pb::generate::v2::*; -use crate::{Chunk, Result}; use base64::engine::general_purpose::STANDARD; use base64::Engine; use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; use std::cmp::min; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; -static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; - /// Text Generation Inference gRPC client #[derive(Debug, Clone)] pub struct Client { @@ -46,7 +45,9 @@ impl Client { #[instrument(skip(self))] pub async fn service_discovery(&mut self) -> Result> { let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); - let response = self.stub.service_discovery(request).await?; + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; let urls = response .into_inner() .urls @@ -133,6 +134,7 @@ impl Client { // Send stringly-typed inputs for compatibility for backends that haven't // been updated to support chunks. + let mut inputs = String::new(); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); if n_tokens == 0 { @@ -145,10 +147,10 @@ impl Client { requests.push(Request { id: 0, + inputs, input_chunks: Some(Input { chunks: input_chunks, }), - inputs, // We truncate the input on the server side to be sure that it has the correct size truncate, // Set sampling parameters to also take these ops into account in the max memory diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs new file mode 100644 index 00000000..4a1296a2 --- /dev/null +++ b/router/client/src/v3/mod.rs @@ -0,0 +1,13 @@ +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod client; +mod sharded_client; + +pub use client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, Tokens, +}; +pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs new file mode 100644 index 00000000..9b4f74d8 --- /dev/null +++ b/router/client/src/v3/sharded_client.rs @@ -0,0 +1,254 @@ +/// Multi shard Client +use crate::{v3, Health, ShardInfo}; +use crate::{ClientError, Result}; + +use crate::v3::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; +use v3::client::{DecodeTimings, PrefillTimings}; +use v3::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. + async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/router/src/health.rs b/router/src/health.rs deleted file mode 100644 index 121255b9..00000000 --- a/router/src/health.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use text_generation_client::{ - Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, -}; -use text_generation_client::{Chunk, GrammarType as ProtoGrammarType}; - -// Note: Request ids and batch ids cannot collide. -const LIVENESS_ID: u64 = u64::MAX; -const BATCH_ID: u64 = u64::MAX; - -#[derive(Clone, Debug)] -pub(crate) struct Health { - client: ShardedClient, - generation_health: Arc, -} - -impl Health { - pub(crate) fn new(client: ShardedClient, generation_health: Arc) -> Self { - Self { - client, - generation_health, - } - } - - pub(crate) async fn check(&mut self) -> bool { - if self.generation_health.load(Ordering::SeqCst) { - // Generation is healthy, we only check that the shards are answering gRPC calls - self.client.health().await.is_ok() - } else { - // Generation is unhealthy or have not sent any generation request yet - - // Dummy batch of 1 token and 1 generated token - let liveness_request = Request { - id: LIVENESS_ID, - input_chunks: Some(Input { - chunks: vec![Chunk::Text("liveness".into()).into()], - }), - inputs: "liveness".to_string(), - truncate: 10, - prefill_logprobs: false, - parameters: Some(NextTokenChooserParameters { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - typical_p: 1.0, - do_sample: false, - seed: 0, - repetition_penalty: 1.0, - frequency_penalty: 0.0, - watermark: false, - grammar: String::new(), - grammar_type: ProtoGrammarType::None as i32, - }), - stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: 1, - stop_sequences: vec![], - ignore_eos_token: false, - }), - top_n_tokens: 0, - }; - let batch = Batch { - id: BATCH_ID, - requests: vec![liveness_request], - size: 1, - max_tokens: 2, - }; - // Skips the queue - let value = self.client.prefill(batch).await.is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value - } - } -} diff --git a/router/src/infer/health.rs b/router/src/infer/health.rs new file mode 100644 index 00000000..4320c1a4 --- /dev/null +++ b/router/src/infer/health.rs @@ -0,0 +1,34 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use text_generation_client::Health; + +#[derive(Clone)] +pub(crate) struct HealthCheck { + client: Arc, + generation_health: Arc, +} + +impl HealthCheck { + pub(crate) fn new( + client: Arc, + generation_health: Arc, + ) -> Self { + Self { + client, + generation_health, + } + } + + pub(crate) async fn check(&mut self) -> bool { + let value = if self.generation_health.load(Ordering::SeqCst) { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok(); + // Update generation health + self.generation_health.store(value, Ordering::SeqCst); + value + } +} diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs new file mode 100644 index 00000000..20630c1b --- /dev/null +++ b/router/src/infer/mod.rs @@ -0,0 +1,522 @@ +mod health; +pub(crate) mod v2; +pub(crate) mod v3; + +pub(crate) use health::HealthCheck; + +use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; +use crate::{ + ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, +}; +use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; +use futures::future::try_join_all; +use minijinja::{Environment, ErrorKind, Template}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; +use thiserror::Error; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; +use tracing::instrument; + +pub(crate) trait Scheduler { + fn schedule( + &self, + request: ValidGenerateRequest, + permit: OwnedSemaphorePermit, + ) -> Result; +} + +/// Inference struct +#[derive(Clone)] +pub struct Infer { + /// Validation + validation: Validation, + /// Request scheduler + scheduler: Arc, + /// Chat template + chat_template: Option, + /// Inference limit + limit_concurrent_requests: Arc, +} + +impl Infer { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + scheduler: Arc, + validation: Validation, + max_concurrent_requests: usize, + tokenizer_config: HubTokenizerConfig, + processor_config: HubProcessorConfig, + ) -> Self { + let chat_template = tokenizer_config + .chat_template + .or(processor_config.chat_template) + .and_then(|t| match t { + ChatTemplateVersions::Single(template) => Some(template), + ChatTemplateVersions::Multiple(templates) => templates + .into_iter() + .find(|t| t.name == "default") + .map(|t| t.template), + }) + .map(|t| { + // .strip() is not supported in minijinja + // .capitalize() is not supported in minijinja but we can use | capitalize + let t = t + .replace(".strip()", " | trim") + .replace(".capitalize()", " | capitalize"); + ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) + }); + + // Inference limit with a semaphore + let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + + Self { + validation, + scheduler, + chat_template, + limit_concurrent_requests: semaphore, + } + } + + /// Add a new request to the queue and return a stream of InferStreamResponse + #[instrument(skip_all)] + pub(crate) async fn generate_stream( + &self, + request: GenerateRequest, + ) -> Result { + // Limit concurrent requests by acquiring a permit from the semaphore + let permit = self + .clone() + .limit_concurrent_requests + .try_acquire_owned() + .map_err(|err| { + metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); + tracing::error!("{err}"); + err + })?; + + // Validate request + let valid_request = self.validation.validate(request).await.map_err(|err| { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + err + })?; + + self.scheduler.schedule(valid_request, permit) + } + + /// Tokenizer the input + #[instrument(skip_all)] + pub(crate) async fn tokenize( + &self, + request: GenerateRequest, + ) -> Result, InferError> { + // Tokenize request + let inputs = request.inputs; + let truncate = request.parameters.truncate; + let encoding = self + .validation + .tokenize(inputs, truncate) + .await + .map_err(|err| { + tracing::error!("Tokenization {err}"); + err + })?; + + // Return Encoding + Ok(encoding.map(|(encoding, _)| encoding)) + } + + /// Apply the chat template to the chat request + #[instrument(skip_all)] + pub(crate) fn apply_chat_template( + &self, + messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + self.chat_template + .as_ref() + .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? + .apply(messages, grammar_with_prompt) + .map_err(|e| { + metrics::increment_counter!("tgi_request_failure", "err" => "template"); + tracing::error!("{e}"); + e + }) + } + + /// Add a new request to the queue and return a InferResponse + #[instrument(skip_all)] + pub(crate) async fn generate( + &self, + request: GenerateRequest, + ) -> Result { + let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); + + // Create stream and keep semaphore permit as long as generate lives + let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + + // Return values + let mut result_prefill = Vec::new(); + let mut result_tokens = Vec::new(); + let mut result_top_tokens = Vec::new(); + let mut result_generated_text = None; + let mut result_start = None; + let mut result_queued = None; + + // Iterate on stream + while let Some(response) = stream.next().await { + match response? { + // Add prefill tokens + InferStreamResponse::Prefill(prefill_tokens) => { + result_prefill = prefill_tokens; + } + // Push last token + InferStreamResponse::Intermediate { token, top_tokens } => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); + } + // Final message + // Set return values + InferStreamResponse::End { + token, + generated_text, + start, + queued, + top_tokens, + } => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); + result_generated_text = Some(generated_text); + result_start = Some(start); + result_queued = Some(queued) + } + } + } + + // Check that we received a `InferStreamResponse::End` message + if let (Some(generated_text), Some(queued), Some(start)) = + (result_generated_text, result_queued, result_start) + { + Ok(InferResponse { + prefill: result_prefill, + _input_length, + tokens: result_tokens, + generated_text, + queued, + start, + top_tokens: if use_top_tokens { + result_top_tokens + } else { + Vec::new() + }, + }) + } else { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); + Err(err) + } + } + /// Add best_of new requests to the queue and return a InferResponse of the sequence with + /// the highest log probability per token + #[instrument(skip(self, request))] + pub(crate) async fn generate_best_of( + &self, + request: GenerateRequest, + best_of: usize, + ) -> Result<(InferResponse, Vec), InferError> { + // validate best_of parameter separately + let best_of = self.validation.validate_best_of(best_of)?; + + // create multiple generate requests + let mut infer_responses: Vec = + try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; + + // get the sequence with the highest log probability per token + let mut max_index = 0; + let mut max_logprob: f32 = f32::MIN; + + for (i, response) in infer_responses.iter().enumerate() { + // mean logprobs of the generated tokens + let sequence_logprob = response + .tokens + .iter() + .map(|token| token.logprob) + .sum::() + / response.tokens.len() as f32; + + // set best sequence + if sequence_logprob > max_logprob { + max_index = i; + max_logprob = sequence_logprob; + } + } + let best_response = infer_responses.remove(max_index); + Ok((best_response, infer_responses)) + } +} + +/// Raise a exception (custom function) used in the chat templates +fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) +} + +#[derive(Clone)] +struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { + fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + let mut env = Box::new(Environment::new()); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); + + // check if contains the tools variable within the template + let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); + + Self { + template, + bos_token, + eos_token, + use_default_tool_template, + } + } + + fn apply( + &self, + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text(Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), + })); + } + } + } + + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) + } +} + +pub struct ToolGrammar {} + +impl ToolGrammar { + pub fn apply( + tools: Option>, + tool_choice: Option, + ) -> Result, InferError> { + if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { + // let tool_prompt = tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .unwrap_or_else(|| panic!("Tool with name {} not found", name)) + .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + return Ok(Some(tools)); + } + // Err(InferError::ToolError("No tools provided".to_string())) + Ok(None) + } +} + +/// Type alias for generation responses +pub(crate) type GenerateStreamResponse = ( + OwnedSemaphorePermit, + u32, // input_length + UnboundedReceiverStream>, +); + +#[derive(Debug)] +pub(crate) struct GeneratedText { + pub(crate) text: String, + pub(crate) generated_tokens: u32, + pub(crate) finish_reason: FinishReason, + pub(crate) seed: Option, +} + +#[derive(Debug)] +pub(crate) enum InferStreamResponse { + // Optional first message + Prefill(Vec), + // Intermediate messages + Intermediate { + token: Token, + top_tokens: Vec, + }, + // Last message + End { + token: Token, + top_tokens: Vec, + generated_text: GeneratedText, + start: Instant, + queued: Instant, + }, +} + +#[derive(Debug)] +pub(crate) struct InferResponse { + /// input_length is the input as perceived by the rust tokenizer in the + /// validation pathway. It is redundant with prefill.len() but prefill + /// has data only if the user asked for it. This will always be filled. + pub(crate) _input_length: u32, + pub(crate) prefill: Vec, + pub(crate) tokens: Vec, + pub(crate) generated_text: GeneratedText, + pub(crate) queued: Instant, + pub(crate) start: Instant, + pub(crate) top_tokens: Vec>, +} + +#[derive(Debug, Error)] +pub enum InferError { + #[error("Request failed during generation: {0}")] + GenerationError(String), + #[error("Model is overloaded")] + Overloaded(#[from] TryAcquireError), + #[error("Input validation error: {0}")] + ValidationError(#[from] ValidationError), + #[error("Incomplete generation")] + IncompleteGeneration, + #[error("Template error: {0}")] + TemplateError(#[from] minijinja::Error), + #[error("Tool error: {0}")] + ToolError(String), +} + +impl InferError { + pub(crate) fn error_type(&self) -> &str { + match self { + InferError::GenerationError(_) => "generation", + InferError::Overloaded(_) => "overloaded", + InferError::ValidationError(_) => "validation", + InferError::IncompleteGeneration => "incomplete_generation", + InferError::TemplateError(_) => "template_error", + InferError::ToolError(_) => "tool_error", + } + } +} diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs new file mode 100644 index 00000000..8b4f6bab --- /dev/null +++ b/router/src/infer/v2/mod.rs @@ -0,0 +1,4 @@ +mod queue; +mod scheduler; + +pub(crate) use scheduler::SchedulerV2; diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs new file mode 100644 index 00000000..3725c03e --- /dev/null +++ b/router/src/infer/v2/queue.rs @@ -0,0 +1,667 @@ +use crate::infer::{InferError, InferStreamResponse}; +use crate::validation::{ + ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +}; +use nohash_hasher::{BuildNoHashHasher, IntMap}; +use std::cmp::min; +use std::collections::VecDeque; +use text_generation_client::v2::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use text_generation_client::ChunksToString; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::Instant; +use tracing::{info_span, instrument, Span}; + +/// Queue entry +#[derive(Debug)] +pub(crate) struct Entry { + /// Request + pub request: ValidGenerateRequest, + /// Response sender to communicate between the Infer struct and the batching_task + pub response_tx: mpsc::UnboundedSender>, + /// Span that will live as long as entry + pub span: Span, + /// Temporary span used as a guard when logging inference, wait times... + pub temp_span: Option, + /// Instant when this entry was queued + pub queue_time: Instant, + /// Instant when this entry was added to a batch + pub batch_time: Option, +} + +/// Request Queue +#[derive(Debug, Clone)] +pub(crate) struct Queue { + /// Channel to communicate with the background queue task + queue_sender: mpsc::UnboundedSender, +} + +impl Queue { + pub(crate) fn new( + requires_padding: bool, + block_size: u32, + window_size: Option, + speculate: u32, + ) -> Self { + // Create channel + let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); + + // Launch background queue task + tokio::spawn(queue_task( + requires_padding, + block_size, + window_size, + speculate, + queue_receiver, + )); + + Self { queue_sender } + } + + #[instrument(skip_all)] + pub(crate) fn append(&self, entry: Entry) { + // Send append command to the background task managing the state + // Unwrap is safe here + self.queue_sender + .send(QueueCommand::Append(Box::new(entry), Span::current())) + .unwrap(); + } + + // Get the next batch + #[instrument(skip(self))] + pub(crate) async fn next_batch( + &self, + min_size: Option, + max_size: Option, + prefill_token_budget: u32, + token_budget: u32, + ) -> Option { + // Create response channel + let (response_sender, response_receiver) = oneshot::channel(); + // Send next batch command to the background task managing the state + // Unwrap is safe here + self.queue_sender + .send(QueueCommand::NextBatch { + min_size, + max_size, + prefill_token_budget, + token_budget, + response_sender, + span: Span::current(), + }) + .unwrap(); + // Await on response channel + // Unwrap is safe here + response_receiver.await.unwrap() + } +} + +// Background task responsible of the queue state +async fn queue_task( + requires_padding: bool, + block_size: u32, + window_size: Option, + speculate: u32, + mut receiver: mpsc::UnboundedReceiver, +) { + let mut state = State::new(requires_padding, block_size, window_size, speculate); + + while let Some(cmd) = receiver.recv().await { + match cmd { + QueueCommand::Append(entry, span) => { + span.in_scope(|| state.append(*entry)); + metrics::increment_gauge!("tgi_queue_size", 1.0); + } + QueueCommand::NextBatch { + min_size, + max_size, + prefill_token_budget, + token_budget, + response_sender, + span, + } => span.in_scope(|| { + let next_batch = + state.next_batch(min_size, max_size, prefill_token_budget, token_budget); + response_sender.send(next_batch).unwrap(); + metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + }), + } + } +} + +/// Queue State +#[derive(Debug)] +struct State { + /// Queue entries organized in a Vec + entries: VecDeque<(u64, Entry)>, + + /// Id of the next entry + next_id: u64, + + /// Id of the next batch + next_batch_id: u64, + + /// Whether the model is using padding + requires_padding: bool, + + /// Paged Attention block size + block_size: u32, + + /// Sliding window + window_size: Option, + + /// Speculation amount + speculate: u32, +} + +impl State { + fn new( + requires_padding: bool, + block_size: u32, + window_size: Option, + speculate: u32, + ) -> Self { + Self { + entries: VecDeque::with_capacity(128), + next_id: 0, + next_batch_id: 0, + requires_padding, + block_size, + window_size, + speculate, + } + } + + /// Append an entry to the queue + fn append(&mut self, mut entry: Entry) { + // Create a span that will live as long as the entry is in the queue waiting to be batched + let queue_span = info_span!(parent: &entry.span, "queued"); + entry.temp_span = Some(queue_span); + + // Push entry in the queue + self.entries.push_back((self.next_id, entry)); + self.next_id += 1; + } + + // Get the next batch + fn next_batch( + &mut self, + min_size: Option, + max_size: Option, + prefill_token_budget: u32, + token_budget: u32, + ) -> Option { + if self.entries.is_empty() { + tracing::debug!("No queue"); + return None; + } + + // Check if we have enough entries + if let Some(min_size) = min_size { + if self.entries.len() < min_size { + tracing::debug!("Not enough entries"); + return None; + } + } + + // Pad prefill_token_budget to be a multiple of block size + let prefill_token_budget = + ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; + + // Create span for this batch to add context to inference calls + let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); + next_batch_span.follows_from(&Span::current()); + + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + + let mut max_input_length = 0; + let mut prefill_tokens: u32 = 0; + let mut decode_tokens: u32 = 0; + + // Pop entries starting from the front of the queue + while let Some((id, mut entry)) = self.entries.pop_front() { + // Filter entries where the response receiver was dropped (== entries where the request + // was dropped by the client) + if entry.response_tx.is_closed() { + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + tracing::debug!("Dropping entry"); + continue; + } + + if self.requires_padding { + // We pad to max input length in the Python shards + // We need to take these padding tokens into the equation + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length + } else { + // pad to block size + prefill_tokens += ((entry.request.input_length + self.block_size - 1) + / self.block_size) + * self.block_size; + } + + if self.requires_padding { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } else { + let max_new_tokens = match self.window_size { + None => entry.request.stopping_parameters.max_new_tokens, + Some(window_size) => min( + window_size.saturating_sub(entry.request.input_length), + entry.request.stopping_parameters.max_new_tokens, + ), + }; + + // pad to block size + decode_tokens += + ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; + } + + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens + self.speculate) > token_budget + { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break; + } + + tracing::debug!("Accepting entry"); + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + + batch_requests.push(Request { + id, + prefill_logprobs: entry.request.decoder_input_details, + inputs: entry.request.inputs.chunks_to_string(), + truncate: entry.request.truncate, + parameters: Some(NextTokenChooserParameters::from( + entry.request.parameters.clone(), + )), + stopping_parameters: Some(StoppingCriteriaParameters::from( + entry.request.stopping_parameters.clone(), + )), + top_n_tokens: entry.request.top_n_tokens, + }); + // Set batch_time + entry.batch_time = Some(Instant::now()); + // Insert in batch_entries IntMap + batch_entries.insert(id, entry); + + // Check if max_size + if Some(batch_requests.len()) == max_size { + break; + } + } + + // Empty batch + if batch_requests.is_empty() { + tracing::debug!("Filtered out all entries"); + return None; + } + + // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch_requests.len() < min_size { + // Add back entries to the queue in the correct order + for r in batch_requests.into_iter().rev() { + let id = r.id; + let entry = batch_entries.remove(&id).unwrap(); + self.entries.push_front((id, entry)); + } + + return None; + } + } + + // Final batch size + let size = batch_requests.len() as u32; + next_batch_span.record("batch_size", size); + + let batch = Batch { + id: self.next_batch_id, + requests: batch_requests, + size, + max_tokens: (prefill_tokens + decode_tokens), + }; + // Increment batch id + self.next_batch_id += 1; + + metrics::histogram!("tgi_batch_next_size", batch.size as f64); + + Some((batch_entries, batch, next_batch_span)) + } +} + +type NextBatch = (IntMap, Batch, Span); + +#[derive(Debug)] +enum QueueCommand { + Append(Box, Span), + NextBatch { + min_size: Option, + max_size: Option, + prefill_token_budget: u32, + token_budget: u32, + response_sender: oneshot::Sender>, + span: Span, + }, +} + +impl From for NextTokenChooserParameters { + fn from(value: ValidParameters) -> Self { + let (grammar, grammar_type) = match value.grammar { + None => (String::new(), GrammarType::None), + + Some(grammar) => match grammar { + ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json), + ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex), + }, + }; + + Self { + temperature: value.temperature, + top_k: value.top_k, + top_p: value.top_p, + typical_p: value.typical_p, + do_sample: value.do_sample, + seed: value.seed, + repetition_penalty: value.repetition_penalty, + frequency_penalty: value.frequency_penalty, + watermark: value.watermark, + grammar, + grammar_type: grammar_type.into(), + } + } +} + +impl From for StoppingCriteriaParameters { + fn from(value: ValidStoppingParameters) -> Self { + Self { + max_new_tokens: value.max_new_tokens, + stop_sequences: value.stop_sequences, + ignore_eos_token: value.ignore_eos_token, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tracing::info_span; + + fn default_entry() -> ( + Entry, + mpsc::UnboundedReceiver>, + ) { + let (response_tx, receiver_tx) = mpsc::unbounded_channel(); + + let entry = Entry { + request: ValidGenerateRequest { + inputs: vec![], + input_length: 0, + truncate: 0, + decoder_input_details: false, + parameters: ValidParameters { + temperature: 0.0, + top_k: 0, + top_p: 0.0, + typical_p: 0.0, + do_sample: false, + seed: 0, + repetition_penalty: 0.0, + frequency_penalty: 0.0, + watermark: false, + grammar: None, + }, + stopping_parameters: ValidStoppingParameters { + ignore_eos_token: false, + max_new_tokens: 1, + stop_sequences: vec![], + }, + top_n_tokens: 0, + }, + response_tx, + span: info_span!("entry"), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }; + (entry, receiver_tx) + } + + #[test] + fn test_append() { + let mut state = State::new(false, 1, None, 0); + let (entry, _guard) = default_entry(); + + assert_eq!(state.next_id, 0); + assert_eq!(state.entries.len(), 0); + + state.append(entry); + + assert_eq!(state.next_id, 1); + assert_eq!(state.entries.len(), 1); + let (id, _) = state.entries.remove(0).unwrap(); + assert_eq!(id, 0); + } + + #[test] + fn test_next_batch_empty() { + let mut state = State::new(false, 1, None, 0); + + assert!(state.next_batch(None, None, 1, 1).is_none()); + assert!(state.next_batch(Some(1), None, 1, 1).is_none()); + } + + #[test] + fn test_next_batch_min_size() { + let mut state = State::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert!(entries.get(&1).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 0); + assert_eq!(state.next_batch_id, 1); + + let (entry3, _guard3) = default_entry(); + state.append(entry3); + + assert!(state.next_batch(Some(2), None, 2, 2).is_none()); + + assert_eq!(state.next_id, 3); + assert_eq!(state.entries.len(), 1); + let (id, _) = state.entries.remove(0).unwrap(); + assert_eq!(id, 2); + } + + #[test] + fn test_next_batch_max_size() { + let mut state = State::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 1); + assert_eq!(state.next_batch_id, 1); + } + + #[test] + fn test_next_batch_token_budget() { + let mut state = State::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 1); + assert_eq!(state.next_batch_id, 1); + + let (entry3, _guard3) = default_entry(); + state.append(entry3); + + let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&1)); + assert!(entries.contains_key(&2)); + assert_eq!(batch.id, 1); + assert_eq!(batch.size, 2); + + assert_eq!(state.next_id, 3); + assert_eq!(state.entries.len(), 0); + assert_eq!(state.next_batch_id, 2); + } + + #[tokio::test] + async fn test_queue_append() { + let queue = Queue::new(false, 1, None, 0); + let (entry, _guard) = default_entry(); + queue.append(entry); + } + + #[tokio::test] + async fn test_queue_next_batch_empty() { + let queue = Queue::new(false, 1, None, 0); + + assert!(queue.next_batch(None, None, 1, 1).await.is_none()); + assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); + } + + #[tokio::test] + async fn test_queue_next_batch_min_size() { + let queue = Queue::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert!(entries.get(&1).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + + let (entry3, _guard3) = default_entry(); + queue.append(entry3); + + // Not enough requests pending + assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none()); + // Not enough token budget + assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none()); + // Ok + let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap(); + assert_eq!(entries2.len(), 1); + assert!(entries2.contains_key(&2)); + assert!(entries2.get(&2).unwrap().batch_time.is_some()); + assert_eq!(batch2.id, 1); + assert_eq!(batch2.size, 1); + } + + #[tokio::test] + async fn test_queue_next_batch_max_size() { + let queue = Queue::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + } + + #[tokio::test] + async fn test_queue_next_batch_token_budget() { + let queue = Queue::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + let (entry3, _guard3) = default_entry(); + queue.append(entry3); + + let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&1)); + assert!(entries.contains_key(&2)); + assert_eq!(batch.id, 1); + assert_eq!(batch.size, 2); + } + + #[tokio::test] + async fn test_queue_next_batch_token_speculate() { + let queue = Queue::new(false, 1, None, 2); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + // Budget of 1 is not enough + assert!(queue.next_batch(None, None, 1, 1).await.is_none()); + + let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + } + + #[tokio::test] + async fn test_queue_next_batch_dropped_receiver() { + let queue = Queue::new(false, 1, None, 0); + let (entry, _) = default_entry(); + queue.append(entry); + + assert!(queue.next_batch(None, None, 1, 1).await.is_none()); + } +} diff --git a/router/src/infer.rs b/router/src/infer/v2/scheduler.rs similarity index 76% rename from router/src/infer.rs rename to router/src/infer/v2/scheduler.rs index 0410de7d..ba6f520d 100644 --- a/router/src/infer.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,79 +1,46 @@ /// Batching and inference logic -use crate::validation::{Validation, ValidationError}; -use crate::{ - ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, - HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, - TextMessage, Token, +use crate::infer::v2::queue::{Entry, Queue}; +use crate::infer::{ + GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, }; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; -use futures::future::try_join_all; -use minijinja::{Environment, ErrorKind, Template}; +use crate::validation::ValidGenerateRequest; +use crate::{FinishReason, PrefillToken, Token}; use nohash_hasher::IntMap; -use serde_json::{json, Map, Value}; -use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use text_generation_client::{ - Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens, -}; -use thiserror::Error; +use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient}; +use text_generation_client::ClientError; use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; +use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_stream::StreamExt; use tracing::{info_span, instrument, Instrument, Span}; -/// Inference struct -#[derive(Clone)] -pub struct Infer { - /// Validation - validation: Validation, +pub(crate) struct SchedulerV2 { /// Request queue queue: Queue, - /// Shared state - shared: Arc, - /// Chat template - chat_template: Option, - /// Inference limit - limit_concurrent_requests: Arc, + /// Notify batcher on queue appends + batching_task_notifier: Arc, } -/// Infer shared state -struct Shared { - /// Batching background Tokio task notifier - batching_task: Notify, -} - -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -impl Infer { +impl SchedulerV2 { #[allow(clippy::too_many_arguments)] pub(crate) fn new( client: ShardedClient, - validation: Validation, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, max_batch_size: Option, - max_concurrent_requests: usize, requires_padding: bool, window_size: Option, speculate: u32, generation_health: Arc, - tokenizer_config: HubTokenizerConfig, - processor_config: HubProcessorConfig, ) -> Self { let queue = Queue::new(requires_padding, 16, window_size, speculate); - let shared = Arc::new(Shared { - batching_task: Notify::new(), - }); + let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic tokio::spawn(batching_task( @@ -84,72 +51,31 @@ impl Infer { max_waiting_tokens, max_batch_size, queue.clone(), - shared.clone(), + batching_task_notifier.clone(), generation_health, )); - let chat_template = tokenizer_config - .chat_template - .or(processor_config.chat_template) - .and_then(|t| match t { - ChatTemplateVersions::Single(template) => Some(template), - ChatTemplateVersions::Multiple(templates) => templates - .into_iter() - .find(|t| t.name == "default") - .map(|t| t.template), - }) - .map(|t| { - // .strip() is not supported in minijinja - // .capitalize() is not supported in minijinja but we can use | capitalize - let t = t - .replace(".strip()", " | trim") - .replace(".capitalize()", " | capitalize"); - ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) - }); - - // Inference limit with a semaphore - let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); - Self { - validation, queue, - shared, - chat_template, - limit_concurrent_requests: semaphore, + batching_task_notifier, } } +} - /// Add a new request to the queue and return a stream of InferStreamResponse +impl Scheduler for SchedulerV2 { #[instrument(skip_all)] - pub(crate) async fn generate_stream( + fn schedule( &self, - request: GenerateRequest, + request: ValidGenerateRequest, + permit: OwnedSemaphorePermit, ) -> Result { - // Limit concurrent requests by acquiring a permit from the semaphore - let permit = self - .clone() - .limit_concurrent_requests - .try_acquire_owned() - .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); - tracing::error!("{err}"); - err - })?; - - // Validate request - let valid_request = self.validation.validate(request).await.map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); - tracing::error!("{err}"); - err - })?; - // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = valid_request.input_length; + let input_length = request.input_length; // Append the request to the queue self.queue.append(Entry { - request: valid_request, + request, response_tx, span: Span::current(), temp_span: None, @@ -159,7 +85,7 @@ impl Infer { // Notify the background task that we have a new entry in the queue that needs // to be batched - self.shared.batching_task.notify_one(); + self.batching_task_notifier.notify_one(); // Return stream Ok(( @@ -168,343 +94,6 @@ impl Infer { UnboundedReceiverStream::new(response_rx), )) } - - /// Tokenizer the input - #[instrument(skip_all)] - pub(crate) async fn tokenize( - &self, - request: GenerateRequest, - ) -> Result, InferError> { - // Tokenize request - let inputs = request.inputs; - let truncate = request.parameters.truncate; - let encoding = self - .validation - .tokenize(inputs, truncate) - .await - .map_err(|err| { - tracing::error!("Tokenization {err}"); - err - })?; - - // Return Encoding - Ok(encoding.map(|(encoding, _)| encoding)) - } - - /// Apply the chat template to the chat request - #[instrument(skip_all)] - pub(crate) fn apply_chat_template( - &self, - messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - self.chat_template - .as_ref() - .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(messages, grammar_with_prompt) - .map_err(|e| { - metrics::increment_counter!("tgi_request_failure", "err" => "template"); - tracing::error!("{e}"); - e - }) - } - - /// Add a new request to the queue and return a InferResponse - #[instrument(skip_all)] - pub(crate) async fn generate( - &self, - request: GenerateRequest, - ) -> Result { - let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); - - // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; - - // Return values - let mut result_prefill = Vec::new(); - let mut result_tokens = Vec::new(); - let mut result_top_tokens = Vec::new(); - let mut result_generated_text = None; - let mut result_start = None; - let mut result_queued = None; - - // Iterate on stream - while let Some(response) = stream.next().await { - match response? { - // Add prefill tokens - InferStreamResponse::Prefill(tokens) => { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - result_prefill = tokens - .ids - .into_iter() - .zip(tokens.logprobs.into_iter()) - .zip(tokens.texts.into_iter()) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - } - // Push last token - InferStreamResponse::Intermediate { token, top_tokens } => { - result_tokens.push(token); - result_top_tokens.push(top_tokens); - } - // Final message - // Set return values - InferStreamResponse::End { - token, - generated_text, - start, - queued, - top_tokens, - } => { - result_tokens.push(token); - result_top_tokens.push(top_tokens); - result_generated_text = Some(generated_text); - result_start = Some(start); - result_queued = Some(queued) - } - } - } - - // Check that we received a `InferStreamResponse::End` message - if let (Some(generated_text), Some(queued), Some(start)) = - (result_generated_text, result_queued, result_start) - { - Ok(InferResponse { - prefill: result_prefill, - _input_length, - tokens: result_tokens, - generated_text, - queued, - start, - top_tokens: if use_top_tokens { - result_top_tokens - } else { - Vec::new() - }, - }) - } else { - let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); - tracing::error!("{err}"); - Err(err) - } - } - /// Add best_of new requests to the queue and return a InferResponse of the sequence with - /// the highest log probability per token - #[instrument(skip(self, request))] - pub(crate) async fn generate_best_of( - &self, - request: GenerateRequest, - best_of: usize, - ) -> Result<(InferResponse, Vec), InferError> { - // validate best_of parameter separately - let best_of = self.validation.validate_best_of(best_of)?; - - // create multiple generate requests - let mut infer_responses: Vec = - try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; - - // get the sequence with the highest log probability per token - let mut max_index = 0; - let mut max_logprob: f32 = f32::MIN; - - for (i, response) in infer_responses.iter().enumerate() { - // mean logprobs of the generated tokens - let sequence_logprob = response - .tokens - .iter() - .map(|token| token.logprob) - .sum::() - / response.tokens.len() as f32; - - // set best sequence - if sequence_logprob > max_logprob { - max_index = i; - max_logprob = sequence_logprob; - } - } - let best_response = infer_responses.remove(max_index); - Ok((best_response, infer_responses)) - } -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { - let mut env = Box::new(Environment::new()); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. - let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token, - eos_token, - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) - } -} - -pub struct ToolGrammar {} - -impl ToolGrammar { - pub fn apply( - tools: Option>, - tool_choice: Option, - ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::OneOf => req_tools.to_owned(), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) - } } /// Batching logic @@ -512,7 +101,7 @@ impl ToolGrammar { /// /// Batches requests and sends them to the inference server #[allow(clippy::too_many_arguments)] -async fn batching_task( +pub(crate) async fn batching_task( mut client: ShardedClient, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, @@ -520,13 +109,13 @@ async fn batching_task( max_waiting_tokens: usize, max_batch_size: Option, queue: Queue, - shared: Arc, + notifier: Arc, generation_health: Arc, ) { // Infinite loop loop { // Wait for a notification from the Infer struct - shared.batching_task.notified().await; + notifier.notified().await; // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests @@ -792,6 +381,16 @@ fn send_responses( let mut stopped = false; if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + // Send message entry .response_tx @@ -842,7 +441,7 @@ fn send_responses( entry.response_tx.send(Ok(InferStreamResponse::End { token, top_tokens, - generated_text: generated_text.clone(), + generated_text: GeneratedText::from(generated_text.clone()), queued: entry.queue_time, start: entry.batch_time.unwrap(), }))?; @@ -877,64 +476,21 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { }); } -#[derive(Debug)] -pub(crate) enum InferStreamResponse { - // Optional first message - Prefill(Tokens), - // Intermediate messages - Intermediate { - token: Token, - top_tokens: Vec, - }, - // Last message - End { - token: Token, - top_tokens: Vec, - generated_text: GeneratedText, - start: Instant, - queued: Instant, - }, -} +impl From for GeneratedText { + fn from(value: text_generation_client::v2::GeneratedText) -> Self { + let v2_finish_reason = + text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v2_finish_reason { + text_generation_client::v2::FinishReason::Length => FinishReason::Length, + text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + text_generation_client::v2::FinishReason::StopSequence => FinishReason::StopSequence, + }; -#[derive(Debug)] -pub(crate) struct InferResponse { - /// input_length is the input as perceived by the rust tokenizer in the - /// validation pathway. It is redundant with prefill.len() but prefill - /// has data only if the user asked for it. This will always be filled. - pub(crate) _input_length: u32, - pub(crate) prefill: Vec, - pub(crate) tokens: Vec, - pub(crate) generated_text: GeneratedText, - pub(crate) queued: Instant, - pub(crate) start: Instant, - pub(crate) top_tokens: Vec>, -} - -#[derive(Debug, Error)] -pub enum InferError { - #[error("Request failed during generation: {0}")] - GenerationError(String), - #[error("Model is overloaded")] - Overloaded(#[from] TryAcquireError), - #[error("Input validation error: {0}")] - ValidationError(#[from] ValidationError), - #[error("Incomplete generation")] - IncompleteGeneration, - #[error("Template error: {0}")] - TemplateError(#[from] minijinja::Error), - #[error("Tool error: {0}")] - ToolError(String), -} - -impl InferError { - pub(crate) fn error_type(&self) -> &str { - match self { - InferError::GenerationError(_) => "generation", - InferError::Overloaded(_) => "overloaded", - InferError::ValidationError(_) => "validation", - InferError::IncompleteGeneration => "incomplete_generation", - InferError::TemplateError(_) => "template_error", - InferError::ToolError(_) => "tool_error", + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, } } } @@ -1355,11 +911,11 @@ mod tests { chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", input: ChatTemplateInputs { messages: vec![ - TextMessage{ + TextMessage { role: "system".to_string(), content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), }, - TextMessage{ + TextMessage { role: "user".to_string(), content: "How many helicopters can a human eat in one sitting?".to_string(), }, diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs new file mode 100644 index 00000000..4299baf3 --- /dev/null +++ b/router/src/infer/v3/mod.rs @@ -0,0 +1,4 @@ +mod queue; +mod scheduler; + +pub(crate) use scheduler::SchedulerV3; diff --git a/router/src/queue.rs b/router/src/infer/v3/queue.rs similarity index 90% rename from router/src/queue.rs rename to router/src/infer/v3/queue.rs index 40692ffc..b926f329 100644 --- a/router/src/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -1,12 +1,14 @@ -use crate::infer::InferError; -use crate::infer::InferStreamResponse; -use crate::validation::ValidGenerateRequest; +use crate::infer::{InferError, InferStreamResponse}; +use crate::validation::{ + ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use std::collections::VecDeque; -use text_generation_client::ChunksToString; -use text_generation_client::Input; -use text_generation_client::{Batch, Request}; +use text_generation_client::v3::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use text_generation_client::{ChunksToString, Input}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -57,7 +59,6 @@ impl Queue { Self { queue_sender } } - /// Append an entry to the queue #[instrument(skip_all)] pub(crate) fn append(&self, entry: Entry) { // Send append command to the background task managing the state @@ -280,13 +281,17 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, + inputs: entry.request.inputs.chunks_to_string(), input_chunks: Some(Input { chunks: entry.request.inputs.clone(), }), - inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, - parameters: Some(entry.request.parameters.clone()), - stopping_parameters: Some(entry.request.stopping_parameters.clone()), + parameters: Some(NextTokenChooserParameters::from( + entry.request.parameters.clone(), + )), + stopping_parameters: Some(StoppingCriteriaParameters::from( + entry.request.stopping_parameters.clone(), + )), top_n_tokens: entry.request.top_n_tokens, }); // Set batch_time @@ -355,12 +360,46 @@ enum QueueCommand { }, } +impl From for NextTokenChooserParameters { + fn from(value: ValidParameters) -> Self { + let (grammar, grammar_type) = match value.grammar { + None => (String::new(), GrammarType::None), + + Some(grammar) => match grammar { + ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json), + ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex), + }, + }; + + Self { + temperature: value.temperature, + top_k: value.top_k, + top_p: value.top_p, + typical_p: value.typical_p, + do_sample: value.do_sample, + seed: value.seed, + repetition_penalty: value.repetition_penalty, + frequency_penalty: value.frequency_penalty, + watermark: value.watermark, + grammar, + grammar_type: grammar_type.into(), + } + } +} + +impl From for StoppingCriteriaParameters { + fn from(value: ValidStoppingParameters) -> Self { + Self { + max_new_tokens: value.max_new_tokens, + stop_sequences: value.stop_sequences, + ignore_eos_token: value.ignore_eos_token, + } + } +} + #[cfg(test)] mod tests { use super::*; - use text_generation_client::{ - GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, - }; use tracing::info_span; fn default_entry() -> ( @@ -375,7 +414,7 @@ mod tests { input_length: 0, truncate: 0, decoder_input_details: false, - parameters: NextTokenChooserParameters { + parameters: ValidParameters { temperature: 0.0, top_k: 0, top_p: 0.0, @@ -385,10 +424,9 @@ mod tests { repetition_penalty: 0.0, frequency_penalty: 0.0, watermark: false, - grammar: String::new(), - grammar_type: ProtoGrammarType::None as i32, + grammar: None, }, - stopping_parameters: StoppingCriteriaParameters { + stopping_parameters: ValidStoppingParameters { ignore_eos_token: false, max_new_tokens: 1, stop_sequences: vec![], diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs new file mode 100644 index 00000000..257d191f --- /dev/null +++ b/router/src/infer/v3/scheduler.rs @@ -0,0 +1,1177 @@ +/// Batching and inference logic +use crate::infer::v3::queue::{Entry, Queue}; +use crate::infer::{ + GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, +}; +use crate::validation::ValidGenerateRequest; +use crate::{FinishReason, PrefillToken, Token}; +use nohash_hasher::IntMap; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; +use text_generation_client::ClientError; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub(crate) struct SchedulerV3 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, +} + +impl SchedulerV3 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + generation_health: Arc, + ) -> Self { + let queue = Queue::new(requires_padding, 16, window_size, speculate); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + generation_health, + )); + + Self { + queue, + batching_task_notifier, + } + } +} + +impl Scheduler for SchedulerV3 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + permit: OwnedSemaphorePermit, + ) -> Result { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + let input_length = request.input_length; + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(( + permit, + input_length, + UnboundedReceiverStream::new(response_rx), + )) + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, + generation_health: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size", batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + } else { + metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = + prefill(&mut client, new_batch, &mut new_entries, &generation_health) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries, &generation_health) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size", 0.0); + metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, + generation_health: &Arc, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + // Update health + generation_health.store(true, Ordering::SeqCst); + + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); + metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); + metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); + metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); + metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + // Update health + generation_health.store(false, Ordering::SeqCst); + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, + generation_health: &Arc, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + // Update health + generation_health.store(true, Ordering::SeqCst); + + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + } + metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); + metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + generation_health.store(false, Ordering::SeqCst); + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: text_generation_client::v3::GeneratedText) -> Self { + let v3_finish_reason = + text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v3_finish_reason { + text_generation_client::v3::FinishReason::Length => FinishReason::Length, + text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} + +// tests +#[cfg(test)] +mod tests { + use crate::infer::raise_exception; + use crate::{ChatTemplateInputs, TextMessage}; + use minijinja::Environment; + + #[test] + fn test_chat_template() { + let env = Environment::new(); + + let source = r#" + {% for message in messages %} + {% if message['role'] == 'system' %} + {% if message['content']%} + {{'### System:\n' + message['content']+'\n\n'}} + {% endif %} + {% elif message['role'] == 'user' %} + {{'### User:\n' + message['content']+'\n\n'}} + {% elif message['role'] == 'assistant' %} + {{'### Assistant:\n' + message['content']}} + {% endif %} + {% if loop.last and add_generation_prompt %} + {{ '### Assistant:\n' }} + {% endif %} + {% endfor %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + + assert_eq!( + result, + "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" + ); + } + + #[test] + fn test_chat_template_invalid_with_raise() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {{ bos_token }} + {% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + {% if message['role'] == 'user' %} + {{ '[INST] ' + message['content'] + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ message['content'] + eos_token}} + {% else %} + {{ raise_exception('Only user and assistant roles are supported!') }} + {% endif %} + {% endfor %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "Hi again!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); + + match result { + Ok(_) => panic!("Should have failed"), + Err(e) => { + assert_eq!( + e.detail().unwrap(), + "Conversation roles must alternate user/assistant/user/assistant/..." + ); + } + } + } + + #[test] + fn test_chat_template_valid_with_raise() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {{ bos_token }} + {% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + {% if message['role'] == 'user' %} + {{ '[INST] ' + message['content'] + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ message['content'] + eos_token}} + {% else %} + {{ raise_exception('Only user and assistant roles are supported!') }} + {% endif %} + {% endfor %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); + } + + #[test] + fn test_chat_template_valid_with_add_generation_prompt() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {% for message in messages %} + {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} + {% endfor %} + {% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} + {% endif %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); + } + + struct ChatTemplateTestItem { + name: &'static str, + chat_template: &'static str, + input: ChatTemplateInputs<'static>, + target: &'static str, + } + + #[test] + fn test_many_chat_templates() { + let example_chat = vec![ + TextMessage { + role: "user".to_string(), + content: "Hello, how are you?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "I'm doing great. How can I help you today?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "I'd like to show off how chat templating works!".to_string(), + }, + ]; + + let example_chat_with_system = [TextMessage { + role: "system".to_string(), + content: "You are a friendly chatbot who always responds in the style of a pirate" + .to_string(), + }] + .iter() + .chain(&example_chat) + .cloned() + .collect::>(); + + let test_default_templates = vec![ + ChatTemplateTestItem { + name: "_base", + chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "blenderbot", + chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "blenderbot_small", + chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "bloom", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "gpt_neox", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, + ChatTemplateTestItem { + name: "gpt2", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, + ChatTemplateTestItem { + name: "llama", + // NOTE: the `.strip()` has been replaced with `| trim` in the following template + chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content | trim + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "whisper", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, + ]; + + #[allow(unused_variables)] // name is unused + for ChatTemplateTestItem { + name, + chat_template, + input, + target, + } in test_default_templates + { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + let tmpl = env.template_from_str(chat_template); + let result = tmpl.unwrap().render(input).unwrap(); + assert_eq!(result, target); + } + + let test_custom_templates = vec![ + ChatTemplateTestItem { + name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)", + chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHello, how are you?<|assistant|>\nI'm doing great. How can I help you today?<|user|>\nI'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)", + chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "system".to_string(), + content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "How many helicopters can a human eat in one sitting?".to_string(), + }, + ], + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>", + }, + ChatTemplateTestItem { + name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", + chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "mistralai/Mistral-7B-Instruct-v0.1", + chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "mistralai/Mixtral-8x7B-Instruct-v0.1", + chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b", + chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "openchat/openchat-3.5-0106", + // `.title()` has been replaced with `| upper` in the following template + chat_template: "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + (message['role'] | title) + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>", + }, + ChatTemplateTestItem { + name: "upstage/SOLAR-10.7B-Instruct-v1.0", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "codellama/CodeLlama-70b-Instruct-hf", + // NOTE: `.strip()` has been replaced with `| trim` in the following template + chat_template: "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\\n\\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\\nDestination: user\\n\\n '}}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n ", + }, + ChatTemplateTestItem { + name: "Deci/DeciLM-7B-instruct", + chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "Qwen/Qwen1.5-72B-Chat", + chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "deepseek-ai/deepseek-llm-7b-chat", + chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\\n\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|begin▁of▁sentence|>"), + eos_token: Some("<|end▁of▁sentence|>"), + ..Default::default() + }, + target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n", + }, + ChatTemplateTestItem { + name: "h2oai/h2o-danube-1.8b-chat", + chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "internlm/internlm2-chat-7b", + chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "TheBloke/deepseek-coder-33B-instruct-AWQ", + chat_template: "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|begin▁of▁sentence|>"), + eos_token: Some("<|EOT|>"), + ..Default::default() + }, + target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n", + }, + ChatTemplateTestItem { + name: "ericzzz/falcon-rw-1b-chat", + // `.strip()` has been replaced with `| trim` in the following template + chat_template: "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'] | trim }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|endoftext|>"), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "abacusai/Smaug-34B-v0.1", + chat_template: "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "maywell/Synatra-Mixtral-8x7B", + chat_template: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "deepseek-ai/deepseek-coder-33b-instruct", + chat_template: "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|begin▁of▁sentence|>"), + eos_token: Some(""), + ..Default::default() + }, + target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n", + }, + // NOT INCLUDED + // - meetkai/functionary-medium-v3.2 + // - fireworks-ai/firefunction-v1 + // https://github + ChatTemplateTestItem { + name: "maywell/PiVoT-MoE", + chat_template: "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", + }, + ]; + + #[allow(unused_variables)] // name is unused + for ChatTemplateTestItem { + name, + chat_template, + input, + target, + } in test_custom_templates + { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + // trim all the whitespace + let chat_template = chat_template + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&chat_template); + let result = tmpl.unwrap().render(input).unwrap(); + assert_eq!(result, target); + } + } +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 9b3283df..b6902c49 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,27 +1,14 @@ -pub mod config; -mod health; /// Text Generation Inference Webserver +pub mod config; mod infer; -mod queue; pub mod server; mod validation; -use infer::{Infer, InferError, InferStreamResponse}; -use queue::{Entry, Queue}; use serde::{Deserialize, Serialize}; -use tokio::sync::OwnedSemaphorePermit; -use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; use utoipa::ToSchema; use validation::Validation; -/// Type alias for generation responses -pub(crate) type GenerateStreamResponse = ( - OwnedSemaphorePermit, - u32, // input_length - UnboundedReceiverStream>, -); - #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")] @@ -158,7 +145,7 @@ pub struct Info { #[schema(example = "4")] pub max_stop_sequences: usize, #[schema(example = "1024")] - pub max_input_length: usize, + pub max_input_tokens: usize, #[schema(example = "2048")] pub max_total_tokens: usize, #[schema(example = "1.2")] @@ -1087,7 +1074,7 @@ pub struct SimpleToken { stop: usize, } -#[derive(Serialize, ToSchema)] +#[derive(Debug, Serialize, ToSchema)] #[serde(rename_all(serialize = "snake_case"))] #[schema(example = "Length")] pub(crate) enum FinishReason { diff --git a/router/src/main.rs b/router/src/main.rs index b526367c..c4203dbc 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -12,7 +12,6 @@ use std::fs::File; use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; -use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::config::Config; use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; use thiserror::Error; @@ -315,59 +314,6 @@ async fn main() -> Result<(), RouterError> { Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", }; - // Instantiate sharded client from the master unix socket - let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(RouterError::Connection)?; - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(RouterError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_supported_batch_total_tokens = match sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(RouterError::Warmup)? - { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens - .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); - tracing::warn!("Model does not support automatic max batch total tokens"); - max_batch_total_tokens - } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." - ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); - } - - max_supported_batch_total_tokens - } - }; - tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); - tracing::info!("Connected"); - // Determine the server port based on the feature and environment variable. let port = if cfg!(feature = "google") { std::env::var("AIP_HTTP_PORT") @@ -387,8 +333,8 @@ async fn main() -> Result<(), RouterError> { // Run server server::run( + master_shard_uds_path, model_info, - shard_info, compat_return_full_text, max_concurrent_requests, max_best_of, @@ -398,10 +344,9 @@ async fn main() -> Result<(), RouterError> { max_total_tokens, waiting_served_ratio, max_batch_prefill_tokens, - max_supported_batch_total_tokens, + max_batch_total_tokens, max_waiting_tokens, max_batch_size, - sharded_client, tokenizer, config, validation_workers, @@ -557,16 +502,8 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option) -> Result<(), (StatusCode, Json)> { +async fn health( + mut health: Extension, +) -> Result<(), (StatusCode, Json)> { match health.check().await { true => Ok(()), false => Err(( @@ -213,9 +218,7 @@ async fn generate_internal( BestOfSequence { generated_text: output_text, - finish_reason: FinishReason::from( - response.generated_text.finish_reason, - ), + finish_reason: response.generated_text.finish_reason, generated_tokens: response.generated_text.generated_tokens, prefill: response.prefill, tokens: response.tokens, @@ -227,7 +230,7 @@ async fn generate_internal( }); Some(Details { - finish_reason: FinishReason::from(response.generated_text.finish_reason), + finish_reason: response.generated_text.finish_reason, generated_tokens: response.generated_text.generated_tokens, prefill: response.prefill, tokens: response.tokens, @@ -468,7 +471,7 @@ async fn generate_stream_internal( // Token details let details = match details { true => Some(StreamDetails { - finish_reason: FinishReason::from(generated_text.finish_reason), + finish_reason: generated_text.finish_reason, generated_tokens: generated_text.generated_tokens, seed: generated_text.seed, }), @@ -556,38 +559,38 @@ async fn generate_stream_internal( /// Generate tokens #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/v1/completions", - request_body = CompletionRequest, - responses( - (status = 200, description = "Generated Chat Completion", - content( - ("application/json" = Completion), - ("text/event-stream" = CompletionCompleteChunk), - )), - (status = 424, description = "Generation Error", body = ErrorResponse, - example = json ! ({"error": "Request failed during generation"})), - (status = 429, description = "Model is overloaded", body = ErrorResponse, - example = json ! ({"error": "Model is overloaded"})), - (status = 422, description = "Input validation error", body = ErrorResponse, - example = json ! ({"error": "Input validation error"})), - (status = 500, description = "Incomplete generation", body = ErrorResponse, - example = json ! ({"error": "Incomplete generation"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/v1/completions", +request_body = CompletionRequest, +responses( +(status = 200, description = "Generated Chat Completion", +content( +("application/json" = Completion), +("text/event-stream" = CompletionCompleteChunk), +)), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] #[instrument( - skip_all, - fields( - // parameters = ? req.parameters, - total_time, - validation_time, - queue_time, - inference_time, - time_per_token, - seed, - ) - )] +skip_all, +fields( +// parameters = ? req.parameters, +total_time, +validation_time, +queue_time, +inference_time, +time_per_token, +seed, +) +)] async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, @@ -961,38 +964,38 @@ async fn completions( /// Generate tokens #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/v1/chat/completions", - request_body = ChatRequest, - responses( - (status = 200, description = "Generated Chat Completion", - content( - ("application/json" = ChatCompletion), - ("text/event-stream" = ChatCompletionChunk), - )), - (status = 424, description = "Generation Error", body = ErrorResponse, - example = json ! ({"error": "Request failed during generation"})), - (status = 429, description = "Model is overloaded", body = ErrorResponse, - example = json ! ({"error": "Model is overloaded"})), - (status = 422, description = "Input validation error", body = ErrorResponse, - example = json ! ({"error": "Input validation error"})), - (status = 500, description = "Incomplete generation", body = ErrorResponse, - example = json ! ({"error": "Incomplete generation"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/v1/chat/completions", +request_body = ChatRequest, +responses( +(status = 200, description = "Generated Chat Completion", +content( +("application/json" = ChatCompletion), +("text/event-stream" = ChatCompletionChunk), +)), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] #[instrument( - skip_all, - fields( - // parameters = ? req.parameters, - total_time, - validation_time, - queue_time, - inference_time, - time_per_token, - seed, - ) - )] +skip_all, +fields( +// parameters = ? req.parameters, +total_time, +validation_time, +queue_time, +inference_time, +time_per_token, +seed, +) +)] async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, @@ -1217,22 +1220,22 @@ async fn chat_completions( /// Generate tokens from Vertex request #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/vertex", - request_body = VertexRequest, - responses( - (status = 200, description = "Generated Text", body = VertexResponse), - (status = 424, description = "Generation Error", body = ErrorResponse, - example = json ! ({"error": "Request failed during generation"})), - (status = 429, description = "Model is overloaded", body = ErrorResponse, - example = json ! ({"error": "Model is overloaded"})), - (status = 422, description = "Input validation error", body = ErrorResponse, - example = json ! ({"error": "Input validation error"})), - (status = 500, description = "Incomplete generation", body = ErrorResponse, - example = json ! ({"error": "Incomplete generation"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/vertex", +request_body = VertexRequest, +responses( +(status = 200, description = "Generated Text", body = VertexResponse), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] #[instrument( skip_all, fields( @@ -1310,16 +1313,16 @@ async fn vertex_compatibility( /// Tokenize inputs #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/tokenize", - request_body = GenerateRequest, - responses( - (status = 200, description = "Tokenized ids", body = TokenizeResponse), - (status = 404, description = "No tokenizer found", body = ErrorResponse, - example = json ! ({"error": "No fast tokenizer available"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/tokenize", +request_body = GenerateRequest, +responses( +(status = 200, description = "Tokenized ids", body = TokenizeResponse), +(status = 404, description = "No tokenizer found", body = ErrorResponse, +example = json ! ({"error": "No fast tokenizer available"})), +) +)] #[instrument(skip_all)] async fn tokenize( Extension(infer): Extension, @@ -1372,21 +1375,20 @@ pub(crate) struct ComputeType(String); /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( + master_shard_uds_path: String, model_info: HubModelInfo, - shard_info: ShardInfo, compat_return_full_text: bool, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, max_top_n_tokens: u32, - max_input_length: usize, + max_input_tokens: usize, max_total_tokens: usize, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, + max_batch_total_tokens: Option, max_waiting_tokens: usize, max_batch_size: Option, - client: ShardedClient, tokenizer: Option, config: Option, validation_workers: usize, @@ -1400,7 +1402,7 @@ pub async fn run( messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, -) -> Result<(), axum::BoxError> { +) -> Result<(), WebServerError> { // OpenAPI documentation #[derive(OpenApi)] #[openapi( @@ -1470,6 +1472,141 @@ pub async fn run( struct ApiDoc; // Create state + + // Open connection, get model info and warmup + let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( + Arc, + HealthCheck, + ShardInfo, + u32, + ) = { + // Helper function to check both v2 and v3 + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), + ); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(WebServerError::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let generation_health = Arc::new(AtomicBool::new(false)); + + match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await { + Ok(mut sharded_client) => { + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(WebServerError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(WebServerError::Warmup)?, + )?; + + let health_ext = + HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); + let scheduler = Arc::new(SchedulerV3::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + generation_health, + )); + tracing::info!("Using scheduler V3"); + + (scheduler, health_ext, shard_info, max_batch_total_tokens) + } + Err(_) => { + let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(WebServerError::Connection)?; + + // server is running on v2 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(WebServerError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(WebServerError::Warmup)?, + )?; + + let health_ext = + HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); + let scheduler = Arc::new(SchedulerV2::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + generation_health, + )); + tracing::info!("Using scheduler V2"); + + (scheduler, health_ext, shard_info, max_batch_total_tokens) + } + } + }; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + let validation = Validation::new( validation_workers, tokenizer, @@ -1477,25 +1614,15 @@ pub async fn run( max_best_of, max_stop_sequences, max_top_n_tokens, - max_input_length, + max_input_tokens, max_total_tokens, grammar_support, ); - let generation_health = Arc::new(AtomicBool::new(false)); - let health_ext = Health::new(client.clone(), generation_health.clone()); + let infer = Infer::new( - client, + scheduler, validation, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, max_concurrent_requests, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, tokenizer_config, processor_config, ); @@ -1514,7 +1641,7 @@ pub async fn run( // Input Length buckets let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length")); let input_length_buckets: Vec = (0..100) - .map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64) + .map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64) .collect(); // Generated tokens buckets let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens")); @@ -1568,7 +1695,7 @@ pub async fn run( max_concurrent_requests, max_best_of, max_stop_sequences, - max_input_length, + max_input_tokens, max_total_tokens, waiting_served_ratio, max_batch_total_tokens, @@ -1664,6 +1791,8 @@ pub async fn run( .layer(OtelAxumLayer::default()) .layer(cors_layer); + tracing::info!("Connected"); + if ngrok { #[cfg(feature = "ngrok")] { @@ -1686,7 +1815,8 @@ pub async fn run( let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) - .await?; + .await + .map_err(|err| WebServerError::Axum(Box::new(err)))?; } Ok(()) } @@ -1719,17 +1849,6 @@ async fn shutdown_signal() { opentelemetry::global::shutdown_tracer_provider(); } -impl From for FinishReason { - fn from(finish_reason: i32) -> Self { - let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap(); - match finish_reason { - text_generation_client::FinishReason::Length => FinishReason::Length, - text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence, - } - } -} - /// Convert to Axum supported formats impl From for (StatusCode, Json) { fn from(err: InferError) -> Self { @@ -1762,3 +1881,19 @@ impl From for Event { .unwrap() } } + +#[derive(Debug, Error)] +pub enum WebServerError { + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), + #[error("Axum error: {0}")] + Axum(#[from] axum::BoxError), +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 863bb99b..bb9ad318 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,20 +1,16 @@ -use crate::config::Config; /// Payload validation logic +use crate::config::Config; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest, GrammarType}; +use base64::{engine::general_purpose::STANDARD, Engine}; +use image::{io::Reader as ImageReader, ImageFormat}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; -use text_generation_client::{ - Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters, - StoppingCriteriaParameters, -}; +use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; -// use tokenizers::TruncationDirection; -use base64::{engine::general_purpose::STANDARD, Engine}; -use image::{io::Reader as ImageReader, ImageFormat}; use tokio::sync::mpsc; use tokio::sync::oneshot; use tracing::{instrument, Span}; @@ -173,10 +169,6 @@ impl Validation { // Validate MaxNewTokens if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { input_length = input_length.saturating_sub(max_new_tokens as usize); - // return Err(ValidationError::MaxNewTokens( - // self.max_total_tokens - self.max_input_length, - // max_new_tokens, - // )); } Ok(( @@ -327,13 +319,13 @@ impl Validation { // compiler and use that to build the FSM here. // Validate grammar and unpack the grammar and type for the proto message - let (grammar, grammar_type) = match grammar { + let grammar = match grammar { Some(grammar) => { // Ensure that grammar is not set if it's not supported if self.disable_grammar_support { return Err(ValidationError::Grammar); } - match grammar { + let valid_grammar = match grammar { GrammarType::Json(json) => { let json = match json { // if value is a string, we need to parse it again to make sure its @@ -350,20 +342,20 @@ impl Validation { .compile(&json) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; - ( - // Serialize json to string + // Serialize json to string + ValidGrammar::Json( serde_json::to_string(&json) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, - ProtoGrammarType::Json.into(), ) } - GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()), - } + GrammarType::Regex(regex) => ValidGrammar::Regex(regex), + }; + Some(valid_grammar) } - None => (String::new(), ProtoGrammarType::None.into()), + None => None, }; - let parameters = NextTokenChooserParameters { + let parameters = ValidParameters { temperature, repetition_penalty, frequency_penalty, @@ -374,9 +366,8 @@ impl Validation { seed, watermark, grammar, - grammar_type, }; - let stopping_parameters = StoppingCriteriaParameters { + let stopping_parameters = ValidStoppingParameters { max_new_tokens, stop_sequences, ignore_eos_token: false, @@ -458,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option { _ => None, } } + fn format_to_mimetype(format: ImageFormat) -> String { match format { ImageFormat::Png => "image/png", @@ -636,14 +628,55 @@ type TokenizerRequest = ( Span, ); +#[derive(Debug, Clone)] +pub(crate) enum ValidGrammar { + Json(String), + Regex(String), +} + +#[derive(Debug, Clone)] +pub(crate) struct ValidParameters { + /// / exponential scaling output probability distribution + pub temperature: f32, + /// / restricting to the k highest probability elements + pub top_k: u32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + pub top_p: f32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + pub typical_p: f32, + /// / apply sampling on the logits + pub do_sample: bool, + /// / random seed for sampling + pub seed: u64, + /// / repetition penalty + pub repetition_penalty: f32, + /// / frequency penalty + pub frequency_penalty: f32, + /// / token watermarking using "A Watermark for Large Language Models" + pub watermark: bool, + /// / grammar (applied if not empty) + pub grammar: Option, +} + +#[derive(Debug, Clone)] +pub(crate) struct ValidStoppingParameters { + /// / Maximum number of generated tokens + pub max_new_tokens: u32, + /// / Optional stopping sequences + pub stop_sequences: Vec, + /// / Ignore end of sequence token + /// / used for benchmarking + pub ignore_eos_token: bool, +} + #[derive(Debug, Clone)] pub(crate) struct ValidGenerateRequest { pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, - pub parameters: NextTokenChooserParameters, - pub stopping_parameters: StoppingCriteriaParameters, + pub parameters: ValidParameters, + pub stopping_parameters: ValidStoppingParameters, pub top_n_tokens: u32, } diff --git a/server/Makefile b/server/Makefile index 32d01709..312f14df 100644 --- a/server/Makefile +++ b/server/Makefile @@ -12,8 +12,8 @@ gen-server: # Compile protos pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir mkdir text_generation_server/pb || true - python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \ - --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto + python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \ + --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; touch text_generation_server/pb/__init__.py From 648dd7b8e19b7bec469a09762373c488a2d995ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 4 Jun 2024 19:37:49 +0200 Subject: [PATCH 025/340] Support GPTQ models with column-packed up/gate tensor (#2006) # What does this PR do? The GPTQ code path for column-packed packed tensors assumed that this is always a QKV matrix. However, models (e.g. Phi-3) can also have column-packed MLP up/gate matrices. Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../text_generation_server/utils/weights.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 2dfd80bf..71d67d82 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -121,24 +121,30 @@ class Weights: ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str): + def _get_qweight(self, name: str, blocks: int): slice_ = self._get_slice(name) total_size = slice_.get_shape()[1] - assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3" - single_size = total_size // 3 + assert ( + total_size % blocks == 0 + ), f"Prepacked quantized matrix is not divisible by {blocks}" + single_size = total_size // blocks world_size = self.process_group.size() rank = self.process_group.rank() assert ( single_size % world_size == 0 - ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards" + ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards" block_size = single_size // world_size start = rank * block_size stop = (rank + 1) * block_size - q = slice_[:, start:stop] - k = slice_[:, start + single_size : stop + single_size] - v = slice_[:, start + 2 * single_size : stop + 2 * single_size] - weight = torch.cat([q, k, v], dim=1) + + weights = [] + for block in range(blocks): + weights.append( + slice_[:, start + block * single_size : stop + block * single_size] + ) + + weight = torch.cat(weights, dim=1) weight = weight.to(device=self.device) return weight @@ -157,7 +163,7 @@ class Weights: from text_generation_server.layers.gptq import GPTQWeight try: - qweight = self._get_qweight(f"{prefix}.qweight") + qweight = self._get_qweight(f"{prefix}.qweight", blocks) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." @@ -165,8 +171,8 @@ class Weights: bits, groupsize, _, quant_method = self._get_gptq_params() - qzeros = self._get_qweight(f"{prefix}.qzeros") - scales = self._get_qweight(f"{prefix}.scales") + qzeros = self._get_qweight(f"{prefix}.qzeros", blocks) + scales = self._get_qweight(f"{prefix}.scales", blocks) scales = scales.to(dtype=self.dtype) if quantize == "gptq" and quant_method == "gptq": From ed8913535b7ed59c710aa317b79e1927b32da7e8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 4 Jun 2024 19:38:46 +0200 Subject: [PATCH 026/340] Making `make install` work better by default. (#2004) # What does this PR do? Making `make install` a much better sane default to start local dev environments. Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- Cargo.toml | 4 + Makefile | 4 + router/client/build.rs | 6 +- server/Makefile | 12 +- server/Makefile-flash-att | 22 +- server/Makefile-flash-att-v2 | 41 ++- server/Makefile-vllm | 43 +-- server/poetry.lock | 504 +++++++++++++++++++---------------- server/pyproject.toml | 10 +- 9 files changed, 347 insertions(+), 299 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8b03a3ee..88497390 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,10 @@ tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } [profile.release] +incremental = true + +[profile.release-binary] +inherits = "release" debug = 1 incremental = true lto = "fat" diff --git a/Makefile b/Makefile index 96e67e2b..27b5efcf 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,10 @@ router-dev: rust-tests: install-router install-launcher cargo test +install-integration-tests: + cd integration-tests && pip install -r requirements.txt + cd clients/python && pip install . + integration-tests: install-integration-tests pytest -s -vv -m "not private" integration-tests diff --git a/router/client/build.rs b/router/client/build.rs index bcfab74f..a7ade9b0 100644 --- a/router/client/build.rs +++ b/router/client/build.rs @@ -13,7 +13,11 @@ fn main() -> Result<(), Box> { .out_dir("src/v2/pb") .include_file("mod.rs") .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) - .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + .map_err(|e| match e.kind(){ + std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")}, + std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")}, + e => {e} + }).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); fs::create_dir_all("src/v3/pb").unwrap_or(()); let mut config = prost_build::Config::new(); diff --git a/server/Makefile b/server/Makefile index 312f14df..5257b876 100644 --- a/server/Makefile +++ b/server/Makefile @@ -10,18 +10,26 @@ unit-tests: gen-server: # Compile protos - pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir + pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir mkdir text_generation_server/pb || true python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \ --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; touch text_generation_server/pb/__init__.py -install: gen-server +install-server: gen-server pip install pip --upgrade pip install -r requirements_cuda.txt pip install -e ".[bnb, accelerate, quantize, peft, outlines]" + +install: install-cuda + echo "Installed server" + +install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention + +install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm + run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index ffa304aa..5570863b 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -1,16 +1,14 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec -flash-attention: - # Clone flash attention - pip install -U packaging ninja --no-cache-dir - git clone https://github.com/HazyResearch/flash-attention.git - -build-flash-attention: flash-attention - cd flash-attention && git fetch && git checkout $(flash_att_commit) - cd flash-attention && python setup.py build - cd flash-attention/csrc/rotary && python setup.py build - cd flash-attention/csrc/layer_norm && python setup.py build +build-flash-attention: + if [ ! -d 'flash-attention' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/HazyResearch/flash-attention.git && \ + cd flash-attention && git fetch && git checkout $(flash_att_commit) && \ + MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build; \ + fi install-flash-attention: build-flash-attention - pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true - cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install + if [ ! -d 'flash-attention' ]; then \ + cd flash-attntion && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install; \ + fi diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index bbff0090..b67803fe 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,29 +1,24 @@ -flash_att_v2_commit_cuda := v2.5.8 +flash_att_v2_commit_cuda := v2.5.9.post1 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 +build-flash-attention-v2-cuda: + pip install -U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) -flash-attention-v2-cuda: - # Clone flash attention - pip install -U packaging ninja --no-cache-dir - git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 +install-flash-attention-v2-cuda: + pip install -U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) -build-flash-attention-v2-cuda: flash-attention-v2-cuda - cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) - cd flash-attention-v2 && git submodule update --init --recursive - cd flash-attention-v2 && python setup.py build - -install-flash-attention-v2-cuda: build-flash-attention-v2-cuda - cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install - -flash-attention-v2-rocm: - # Clone flash attention - pip install -U packaging ninja --no-cache-dir - git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 - -build-flash-attention-v2-rocm: flash-attention-v2-rocm - cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) - cd flash-attention-v2 && git submodule update --init --recursive - cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build +build-flash-attention-v2-rocm: + if [ ! -d 'flash-attention-v2' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \ + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \ + git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi install-flash-attention-v2-rocm: build-flash-attention-v2-rocm - cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install + if [ ! -d 'flash-attention-v2' ]; then \ + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install; \ + fi diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 62fa413f..de3b4611 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,25 +1,26 @@ -vllm-cuda: - # Clone vllm - pip install -U ninja packaging --no-cache-dir - git clone https://github.com/Narsil/vllm.git vllm - -build-vllm-cuda: vllm-cuda - cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa - cd vllm && python setup.py build - +build-vllm-cuda: + if [ ! -d 'vllm' ]; then \ + pip install -U ninja packaging --no-cache-dir && \ + git clone https://github.com/Narsil/vllm.git vllm &&\ + cd vllm && \ + git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa &&\ + python setup.py build; \ + fi install-vllm-cuda: build-vllm-cuda - pip uninstall vllm -y || true - cd vllm && python setup.py install + if [ ! -d 'vllm' ]; then \ + cd vllm && pip install -e .; \ + fi -vllm-rocm: - # Clone vllm - pip install -U ninja packaging --no-cache-dir - git clone https://github.com/fxmarty/rocm-vllm.git vllm - -build-vllm-rocm: vllm-rocm - cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 - cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install +build-vllm-rocm: + if [ ! -d 'vllm' ]; then \ + pip install -U ninja packaging --no-cache-dir && \ + git clone https://github.com/fxmarty/rocm-vllm.git vllm && \ + cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 && \ + PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi install-vllm-rocm: build-vllm-rocm - pip uninstall vllm -y || true - cd vllm && python setup.py install + if [ ! -d 'vllm' ]; then \ + cd vllm && \ + PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .; \ + fi diff --git a/server/poetry.lock b/server/poetry.lock index 2bf4ca22..4984978a 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "accelerate" @@ -181,17 +181,6 @@ tests = ["attrs[tests-no-zope]", "zope-interface"] tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] -[[package]] -name = "backoff" -version = "2.2.1" -description = "Function decoration for backoff and retry" -optional = false -python-versions = ">=3.7,<4.0" -files = [ - {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, - {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, -] - [[package]] name = "bitsandbytes" version = "0.43.1" @@ -213,13 +202,13 @@ test = ["scipy"] [[package]] name = "certifi" -version = "2024.2.2" +version = "2024.6.2" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, - {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, + {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, + {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, ] [[package]] @@ -570,13 +559,13 @@ files = [ [[package]] name = "fsspec" -version = "2024.5.0" +version = "2024.6.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"}, - {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"}, + {file = "fsspec-2024.6.0-py3-none-any.whl", hash = "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee"}, + {file = "fsspec-2024.6.0.tar.gz", hash = "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2"}, ] [package.dependencies] @@ -588,6 +577,7 @@ adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] dev = ["pre-commit", "ruff"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] @@ -611,17 +601,17 @@ tqdm = ["tqdm"] [[package]] name = "googleapis-common-protos" -version = "1.63.0" +version = "1.63.1" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" files = [ - {file = "googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e"}, - {file = "googleapis_common_protos-1.63.0-py2.py3-none-any.whl", hash = "sha256:ae45f75702f7c08b541f750854a678bd8f534a1a6bace6afe975f1d0a82d6632"}, + {file = "googleapis-common-protos-1.63.1.tar.gz", hash = "sha256:c6442f7a0a6b2a80369457d79e6672bb7dcbaab88e0848302497e3ec80780a6a"}, + {file = "googleapis_common_protos-1.63.1-py2.py3-none-any.whl", hash = "sha256:0e1c2cdfcbc354b76e4a211a35ea35d6926a835cba1377073c4861db904a1877"}, ] [package.dependencies] -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] @@ -645,61 +635,61 @@ testing = ["protobuf (>=4.21.9)"] [[package]] name = "grpcio" -version = "1.64.0" +version = "1.64.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.64.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3b09c3d9de95461214a11d82cc0e6a46a6f4e1f91834b50782f932895215e5db"}, - {file = "grpcio-1.64.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7e013428ab472892830287dd082b7d129f4d8afef49227a28223a77337555eaa"}, - {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:02cc9cc3f816d30f7993d0d408043b4a7d6a02346d251694d8ab1f78cc723e7e"}, - {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f5de082d936e0208ce8db9095821361dfa97af8767a6607ae71425ac8ace15c"}, - {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b7bf346391dffa182fba42506adf3a84f4a718a05e445b37824136047686a1"}, - {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b2cbdfba18408389a1371f8c2af1659119e1831e5ed24c240cae9e27b4abc38d"}, - {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:aca4f15427d2df592e0c8f3d38847e25135e4092d7f70f02452c0e90d6a02d6d"}, - {file = "grpcio-1.64.0-cp310-cp310-win32.whl", hash = "sha256:7c1f5b2298244472bcda49b599be04579f26425af0fd80d3f2eb5fd8bc84d106"}, - {file = "grpcio-1.64.0-cp310-cp310-win_amd64.whl", hash = "sha256:73f84f9e5985a532e47880b3924867de16fa1aa513fff9b26106220c253c70c5"}, - {file = "grpcio-1.64.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2a18090371d138a57714ee9bffd6c9c9cb2e02ce42c681aac093ae1e7189ed21"}, - {file = "grpcio-1.64.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:59c68df3a934a586c3473d15956d23a618b8f05b5e7a3a904d40300e9c69cbf0"}, - {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b52e1ec7185512103dd47d41cf34ea78e7a7361ba460187ddd2416b480e0938c"}, - {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d598b5d5e2c9115d7fb7e2cb5508d14286af506a75950762aa1372d60e41851"}, - {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01615bbcae6875eee8091e6b9414072f4e4b00d8b7e141f89635bdae7cf784e5"}, - {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0b2dfe6dcace264807d9123d483d4c43274e3f8c39f90ff51de538245d7a4145"}, - {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7f17572dc9acd5e6dfd3014d10c0b533e9f79cd9517fc10b0225746f4c24b58e"}, - {file = "grpcio-1.64.0-cp311-cp311-win32.whl", hash = "sha256:6ec5ed15b4ffe56e2c6bc76af45e6b591c9be0224b3fb090adfb205c9012367d"}, - {file = "grpcio-1.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:597191370951b477b7a1441e1aaa5cacebeb46a3b0bd240ec3bb2f28298c7553"}, - {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, - {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, - {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, - {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, - {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, - {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, - {file = "grpcio-1.64.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c46fb6bfca17bfc49f011eb53416e61472fa96caa0979b4329176bdd38cbbf2a"}, - {file = "grpcio-1.64.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3d2004e85cf5213995d09408501f82c8534700d2babeb81dfdba2a3bff0bb396"}, - {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6d5541eb460d73a07418524fb64dcfe0adfbcd32e2dac0f8f90ce5b9dd6c046c"}, - {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f279ad72dd7d64412e10f2443f9f34872a938c67387863c4cd2fb837f53e7d2"}, - {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85fda90b81da25993aa47fae66cae747b921f8f6777550895fb62375b776a231"}, - {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a053584079b793a54bece4a7d1d1b5c0645bdbee729215cd433703dc2532f72b"}, - {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:579dd9fb11bc73f0de061cab5f8b2def21480fd99eb3743ed041ad6a1913ee2f"}, - {file = "grpcio-1.64.0-cp38-cp38-win32.whl", hash = "sha256:23b6887bb21d77649d022fa1859e05853fdc2e60682fd86c3db652a555a282e0"}, - {file = "grpcio-1.64.0-cp38-cp38-win_amd64.whl", hash = "sha256:753cb58683ba0c545306f4e17dabf468d29cb6f6b11832e1e432160bb3f8403c"}, - {file = "grpcio-1.64.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:2186d76a7e383e1466e0ea2b0febc343ffeae13928c63c6ec6826533c2d69590"}, - {file = "grpcio-1.64.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0f30596cdcbed3c98024fb4f1d91745146385b3f9fd10c9f2270cbfe2ed7ed91"}, - {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:d9171f025a196f5bcfec7e8e7ffb7c3535f7d60aecd3503f9e250296c7cfc150"}, - {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf4c8daed18ae2be2f1fc7d613a76ee2a2e28fdf2412d5c128be23144d28283d"}, - {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3550493ac1d23198d46dc9c9b24b411cef613798dc31160c7138568ec26bc9b4"}, - {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3161a8f8bb38077a6470508c1a7301cd54301c53b8a34bb83e3c9764874ecabd"}, - {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e8fabe2cc57a369638ab1ad8e6043721014fdf9a13baa7c0e35995d3a4a7618"}, - {file = "grpcio-1.64.0-cp39-cp39-win32.whl", hash = "sha256:31890b24d47b62cc27da49a462efe3d02f3c120edb0e6c46dcc0025506acf004"}, - {file = "grpcio-1.64.0-cp39-cp39-win_amd64.whl", hash = "sha256:5a56797dea8c02e7d3a85dfea879f286175cf4d14fbd9ab3ef2477277b927baa"}, - {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, + {file = "grpcio-1.64.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:55697ecec192bc3f2f3cc13a295ab670f51de29884ca9ae6cd6247df55df2502"}, + {file = "grpcio-1.64.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3b64ae304c175671efdaa7ec9ae2cc36996b681eb63ca39c464958396697daff"}, + {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:bac71b4b28bc9af61efcdc7630b166440bbfbaa80940c9a697271b5e1dabbc61"}, + {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c024ffc22d6dc59000faf8ad781696d81e8e38f4078cb0f2630b4a3cf231a90"}, + {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7cd5c1325f6808b8ae31657d281aadb2a51ac11ab081ae335f4f7fc44c1721d"}, + {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0a2813093ddb27418a4c99f9b1c223fab0b053157176a64cc9db0f4557b69bd9"}, + {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2981c7365a9353f9b5c864595c510c983251b1ab403e05b1ccc70a3d9541a73b"}, + {file = "grpcio-1.64.1-cp310-cp310-win32.whl", hash = "sha256:1262402af5a511c245c3ae918167eca57342c72320dffae5d9b51840c4b2f86d"}, + {file = "grpcio-1.64.1-cp310-cp310-win_amd64.whl", hash = "sha256:19264fc964576ddb065368cae953f8d0514ecc6cb3da8903766d9fb9d4554c33"}, + {file = "grpcio-1.64.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:58b1041e7c870bb30ee41d3090cbd6f0851f30ae4eb68228955d973d3efa2e61"}, + {file = "grpcio-1.64.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bbc5b1d78a7822b0a84c6f8917faa986c1a744e65d762ef6d8be9d75677af2ca"}, + {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5841dd1f284bd1b3d8a6eca3a7f062b06f1eec09b184397e1d1d43447e89a7ae"}, + {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8caee47e970b92b3dd948371230fcceb80d3f2277b3bf7fbd7c0564e7d39068e"}, + {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73819689c169417a4f978e562d24f2def2be75739c4bed1992435d007819da1b"}, + {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6503b64c8b2dfad299749cad1b595c650c91e5b2c8a1b775380fcf8d2cbba1e9"}, + {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1de403fc1305fd96cfa75e83be3dee8538f2413a6b1685b8452301c7ba33c294"}, + {file = "grpcio-1.64.1-cp311-cp311-win32.whl", hash = "sha256:d4d29cc612e1332237877dfa7fe687157973aab1d63bd0f84cf06692f04c0367"}, + {file = "grpcio-1.64.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e56462b05a6f860b72f0fa50dca06d5b26543a4e88d0396259a07dc30f4e5aa"}, + {file = "grpcio-1.64.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59"}, + {file = "grpcio-1.64.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1"}, + {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb"}, + {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb"}, + {file = "grpcio-1.64.1-cp312-cp312-win32.whl", hash = "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027"}, + {file = "grpcio-1.64.1-cp312-cp312-win_amd64.whl", hash = "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6"}, + {file = "grpcio-1.64.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:a011ac6c03cfe162ff2b727bcb530567826cec85eb8d4ad2bfb4bd023287a52d"}, + {file = "grpcio-1.64.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4d6dab6124225496010bd22690f2d9bd35c7cbb267b3f14e7a3eb05c911325d4"}, + {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:a5e771d0252e871ce194d0fdcafd13971f1aae0ddacc5f25615030d5df55c3a2"}, + {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c3c1b90ab93fed424e454e93c0ed0b9d552bdf1b0929712b094f5ecfe7a23ad"}, + {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20405cb8b13fd779135df23fabadc53b86522d0f1cba8cca0e87968587f50650"}, + {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0cc79c982ccb2feec8aad0e8fb0d168bcbca85bc77b080d0d3c5f2f15c24ea8f"}, + {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a3a035c37ce7565b8f4f35ff683a4db34d24e53dc487e47438e434eb3f701b2a"}, + {file = "grpcio-1.64.1-cp38-cp38-win32.whl", hash = "sha256:1257b76748612aca0f89beec7fa0615727fd6f2a1ad580a9638816a4b2eb18fd"}, + {file = "grpcio-1.64.1-cp38-cp38-win_amd64.whl", hash = "sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122"}, + {file = "grpcio-1.64.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:75dbbf415026d2862192fe1b28d71f209e2fd87079d98470db90bebe57b33179"}, + {file = "grpcio-1.64.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e3d9f8d1221baa0ced7ec7322a981e28deb23749c76eeeb3d33e18b72935ab62"}, + {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:5f8b75f64d5d324c565b263c67dbe4f0af595635bbdd93bb1a88189fc62ed2e5"}, + {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c84ad903d0d94311a2b7eea608da163dace97c5fe9412ea311e72c3684925602"}, + {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:940e3ec884520155f68a3b712d045e077d61c520a195d1a5932c531f11883489"}, + {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309"}, + {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac15b6c2c80a4d1338b04d42a02d376a53395ddf0ec9ab157cbaf44191f3ffdd"}, + {file = "grpcio-1.64.1-cp39-cp39-win32.whl", hash = "sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040"}, + {file = "grpcio-1.64.1-cp39-cp39-win_amd64.whl", hash = "sha256:ed6091fa0adcc7e4ff944090cf203a52da35c37a130efa564ded02b7aff63bcd"}, + {file = "grpcio-1.64.1.tar.gz", hash = "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.64.0)"] +protobuf = ["grpcio-tools (>=1.64.1)"] [[package]] name = "grpcio-reflection" @@ -874,13 +864,13 @@ files = [ [[package]] name = "huggingface-hub" -version = "0.23.1" +version = "0.23.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, - {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, + {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, + {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, ] [package.dependencies] @@ -917,6 +907,25 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] +[[package]] +name = "importlib-metadata" +version = "7.1.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1564,87 +1573,97 @@ files = [ [[package]] name = "opentelemetry-api" -version = "1.15.0" +version = "1.25.0" description = "OpenTelemetry Python API" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_api-1.15.0-py3-none-any.whl", hash = "sha256:e6c2d2e42140fd396e96edf75a7ceb11073f4efb4db87565a431cc9d0f93f2e0"}, - {file = "opentelemetry_api-1.15.0.tar.gz", hash = "sha256:79ab791b4aaad27acc3dc3ba01596db5b5aac2ef75c70622c6038051d6c2cded"}, + {file = "opentelemetry_api-1.25.0-py3-none-any.whl", hash = "sha256:757fa1aa020a0f8fa139f8959e53dec2051cc26b832e76fa839a6d76ecefd737"}, + {file = "opentelemetry_api-1.25.0.tar.gz", hash = "sha256:77c4985f62f2614e42ce77ee4c9da5fa5f0bc1e1821085e9a47533a9323ae869"}, ] [package.dependencies] deprecated = ">=1.2.6" -setuptools = ">=16.0" +importlib-metadata = ">=6.0,<=7.1" [[package]] name = "opentelemetry-exporter-otlp" -version = "1.15.0" +version = "1.25.0" description = "OpenTelemetry Collector Exporters" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp-1.15.0-py3-none-any.whl", hash = "sha256:79f22748b6a54808a0448093dfa189c8490e729f67c134d4c992533d9393b33e"}, - {file = "opentelemetry_exporter_otlp-1.15.0.tar.gz", hash = "sha256:4f7c49751d9720e2e726e13b0bb958ccade4e29122c305d92c033da432c8d2c5"}, + {file = "opentelemetry_exporter_otlp-1.25.0-py3-none-any.whl", hash = "sha256:d67a831757014a3bc3174e4cd629ae1493b7ba8d189e8a007003cacb9f1a6b60"}, + {file = "opentelemetry_exporter_otlp-1.25.0.tar.gz", hash = "sha256:ce03199c1680a845f82e12c0a6a8f61036048c07ec7a0bd943142aca8fa6ced0"}, ] [package.dependencies] -opentelemetry-exporter-otlp-proto-grpc = "1.15.0" -opentelemetry-exporter-otlp-proto-http = "1.15.0" +opentelemetry-exporter-otlp-proto-grpc = "1.25.0" +opentelemetry-exporter-otlp-proto-http = "1.25.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.25.0" +description = "OpenTelemetry Protobuf encoding" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_exporter_otlp_proto_common-1.25.0-py3-none-any.whl", hash = "sha256:15637b7d580c2675f70246563363775b4e6de947871e01d0f4e3881d1848d693"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.25.0.tar.gz", hash = "sha256:c93f4e30da4eee02bacd1e004eb82ce4da143a2f8e15b987a9f603e0a85407d3"}, +] + +[package.dependencies] +opentelemetry-proto = "1.25.0" [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.15.0" +version = "1.25.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0-py3-none-any.whl", hash = "sha256:c2a5492ba7d140109968135d641d06ce3c5bd73c50665f787526065d57d7fd1d"}, - {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0.tar.gz", hash = "sha256:844f2a4bb9bcda34e4eb6fe36765e5031aacb36dc60ed88c90fc246942ea26e7"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0-py3-none-any.whl", hash = "sha256:3131028f0c0a155a64c430ca600fd658e8e37043cb13209f0109db5c1a3e4eb4"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0.tar.gz", hash = "sha256:c0b1661415acec5af87625587efa1ccab68b873745ca0ee96b69bb1042087eac"}, ] [package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" grpcio = ">=1.0.0,<2.0.0" -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-proto = "1.15.0" -opentelemetry-sdk = ">=1.12,<2.0" - -[package.extras] -test = ["pytest-grpc"] +opentelemetry-api = ">=1.15,<2.0" +opentelemetry-exporter-otlp-proto-common = "1.25.0" +opentelemetry-proto = "1.25.0" +opentelemetry-sdk = ">=1.25.0,<1.26.0" [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.15.0" +version = "1.25.0" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_http-1.15.0-py3-none-any.whl", hash = "sha256:3ec2a02196c8a54bf5cbf7fe623a5238625638e83b6047a983bdf96e2bbb74c0"}, - {file = "opentelemetry_exporter_otlp_proto_http-1.15.0.tar.gz", hash = "sha256:11b2c814249a49b22f6cca7a06b05701f561d577b747f3660dfd67b6eb9daf9c"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.25.0-py3-none-any.whl", hash = "sha256:2eca686ee11b27acd28198b3ea5e5863a53d1266b91cda47c839d95d5e0541a6"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.25.0.tar.gz", hash = "sha256:9f8723859e37c75183ea7afa73a3542f01d0fd274a5b97487ea24cb683d7d684"}, ] [package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-proto = "1.15.0" -opentelemetry-sdk = ">=1.12,<2.0" +opentelemetry-api = ">=1.15,<2.0" +opentelemetry-exporter-otlp-proto-common = "1.25.0" +opentelemetry-proto = "1.25.0" +opentelemetry-sdk = ">=1.25.0,<1.26.0" requests = ">=2.7,<3.0" -[package.extras] -test = ["responses (==0.22.0)"] - [[package]] name = "opentelemetry-instrumentation" -version = "0.36b0" +version = "0.46b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation-0.36b0-py3-none-any.whl", hash = "sha256:83ba4ae7d5292b5b33e0f851cc5c76d8f91196b9b3527800fc13855c33383ac2"}, - {file = "opentelemetry_instrumentation-0.36b0.tar.gz", hash = "sha256:e3ddac9b3b93408ef26c8ecbf38f717042977e16381bb4cd329a5b4cf16998cf"}, + {file = "opentelemetry_instrumentation-0.46b0-py3-none-any.whl", hash = "sha256:89cd721b9c18c014ca848ccd11181e6b3fd3f6c7669e35d59c48dc527408c18b"}, + {file = "opentelemetry_instrumentation-0.46b0.tar.gz", hash = "sha256:974e0888fb2a1e01c38fbacc9483d024bb1132aad92d6d24e2e5543887a7adda"}, ] [package.dependencies] @@ -1654,35 +1673,33 @@ wrapt = ">=1.0.0,<2.0.0" [[package]] name = "opentelemetry-instrumentation-grpc" -version = "0.36b0" +version = "0.46b0" description = "OpenTelemetry gRPC instrumentation" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_grpc-0.36b0-py3-none-any.whl", hash = "sha256:eaa246ed2083c97b13bab2555cb9d170e8433230a31476c4cab8a17fa03380a4"}, - {file = "opentelemetry_instrumentation_grpc-0.36b0.tar.gz", hash = "sha256:dc89447c9eb6ea868970f6c13b4ffdac182cdd5a41dd215a0f5393ca6375be55"}, + {file = "opentelemetry_instrumentation_grpc-0.46b0-py3-none-any.whl", hash = "sha256:cccfb28db07c28849709f2dcf330237fae0fca9f86971bfce27b28bb9a8b0577"}, + {file = "opentelemetry_instrumentation_grpc-0.46b0.tar.gz", hash = "sha256:9c5738592cf82672805099826b676d352324b54e03f9ac72a1368ba0605d6ff9"}, ] [package.dependencies] opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.36b0" -opentelemetry-sdk = ">=1.12,<2.0" -opentelemetry-semantic-conventions = "0.36b0" +opentelemetry-instrumentation = "0.46b0" +opentelemetry-semantic-conventions = "0.46b0" wrapt = ">=1.0.0,<2.0.0" [package.extras] instruments = ["grpcio (>=1.27,<2.0)"] -test = ["opentelemetry-instrumentation-grpc[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.36b0)", "protobuf (>=3.13,<4.0)"] [[package]] name = "opentelemetry-proto" -version = "1.15.0" +version = "1.25.0" description = "OpenTelemetry Python Proto" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_proto-1.15.0-py3-none-any.whl", hash = "sha256:044b6d044b4d10530f250856f933442b8753a17f94ae37c207607f733fb9a844"}, - {file = "opentelemetry_proto-1.15.0.tar.gz", hash = "sha256:9c4008e40ac8cab359daac283fbe7002c5c29c77ea2674ad5626a249e64e0101"}, + {file = "opentelemetry_proto-1.25.0-py3-none-any.whl", hash = "sha256:f07e3341c78d835d9b86665903b199893befa5e98866f63d22b00d0b7ca4972f"}, + {file = "opentelemetry_proto-1.25.0.tar.gz", hash = "sha256:35b6ef9dc4a9f7853ecc5006738ad40443701e52c26099e197895cbda8b815a3"}, ] [package.dependencies] @@ -1690,41 +1707,43 @@ protobuf = ">=3.19,<5.0" [[package]] name = "opentelemetry-sdk" -version = "1.15.0" +version = "1.25.0" description = "OpenTelemetry Python SDK" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_sdk-1.15.0-py3-none-any.whl", hash = "sha256:555c533e9837766119bbccc7a80458c9971d853a6f1da683a2246cd5e53b4645"}, - {file = "opentelemetry_sdk-1.15.0.tar.gz", hash = "sha256:98dbffcfeebcbff12c0c974292d6ea603180a145904cf838b1fe4d5c99078425"}, + {file = "opentelemetry_sdk-1.25.0-py3-none-any.whl", hash = "sha256:d97ff7ec4b351692e9d5a15af570c693b8715ad78b8aafbec5c7100fe966b4c9"}, + {file = "opentelemetry_sdk-1.25.0.tar.gz", hash = "sha256:ce7fc319c57707ef5bf8b74fb9f8ebdb8bfafbe11898410e0d2a761d08a98ec7"}, ] [package.dependencies] -opentelemetry-api = "1.15.0" -opentelemetry-semantic-conventions = "0.36b0" -setuptools = ">=16.0" +opentelemetry-api = "1.25.0" +opentelemetry-semantic-conventions = "0.46b0" typing-extensions = ">=3.7.4" [[package]] name = "opentelemetry-semantic-conventions" -version = "0.36b0" +version = "0.46b0" description = "OpenTelemetry Semantic Conventions" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_semantic_conventions-0.36b0-py3-none-any.whl", hash = "sha256:adc05635e87b9d3e007c9f530eed487fc3ef2177d02f82f674f28ebf9aff8243"}, - {file = "opentelemetry_semantic_conventions-0.36b0.tar.gz", hash = "sha256:829dc221795467d98b773c04096e29be038d77526dc8d6ac76f546fb6279bf01"}, + {file = "opentelemetry_semantic_conventions-0.46b0-py3-none-any.whl", hash = "sha256:6daef4ef9fa51d51855d9f8e0ccd3a1bd59e0e545abe99ac6203804e36ab3e07"}, + {file = "opentelemetry_semantic_conventions-0.46b0.tar.gz", hash = "sha256:fbc982ecbb6a6e90869b15c1673be90bd18c8a56ff1cffc0864e38e2edffaefa"}, ] +[package.dependencies] +opentelemetry-api = "1.25.0" + [[package]] name = "outlines" -version = "0.0.36" +version = "0.0.34" description = "Probabilistic Generative Model Programming" optional = true python-versions = ">=3.8" files = [ - {file = "outlines-0.0.36-py3-none-any.whl", hash = "sha256:afa02ca5c449c47731fa06af66d13c2f5ee8b30f8b82b4db90e08215d6f111d1"}, - {file = "outlines-0.0.36.tar.gz", hash = "sha256:3cffb43143548cd78c6061990feb461cffd5479999391b8390471ea839c2d46e"}, + {file = "outlines-0.0.34-py3-none-any.whl", hash = "sha256:911588a7e64a4f193b97fb4c501d98ccfd4e95a98f6a3ada67a280bf0c373c50"}, + {file = "outlines-0.0.34.tar.gz", hash = "sha256:594e7204c770b47a62eb5c2ba7d25ea0ab2e16882b5f04556712a0228d3d3309"}, ] [package.dependencies] @@ -1747,7 +1766,7 @@ transformers = "*" [package.extras] serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] -test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "openai (>=1.0.0)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python (>=0.2.42)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] [[package]] name = "packaging" @@ -2086,18 +2105,18 @@ numpy = ">=1.16.6" [[package]] name = "pydantic" -version = "2.7.1" +version = "2.7.3" description = "Data validation using Python type hints" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic-2.7.1-py3-none-any.whl", hash = "sha256:e029badca45266732a9a79898a15ae2e8b14840b1eabbb25844be28f0b33f3d5"}, - {file = "pydantic-2.7.1.tar.gz", hash = "sha256:e9dbb5eada8abe4d9ae5f46b9939aead650cd2b68f249bb3a8139dbe125803cc"}, + {file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"}, + {file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"}, ] [package.dependencies] annotated-types = ">=0.4.0" -pydantic-core = "2.18.2" +pydantic-core = "2.18.4" typing-extensions = ">=4.6.1" [package.extras] @@ -2105,90 +2124,90 @@ email = ["email-validator (>=2.0.0)"] [[package]] name = "pydantic-core" -version = "2.18.2" +version = "2.18.4" description = "Core functionality for Pydantic validation and serialization" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.18.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9e08e867b306f525802df7cd16c44ff5ebbe747ff0ca6cf3fde7f36c05a59a81"}, - {file = "pydantic_core-2.18.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f0a21cbaa69900cbe1a2e7cad2aa74ac3cf21b10c3efb0fa0b80305274c0e8a2"}, - {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0680b1f1f11fda801397de52c36ce38ef1c1dc841a0927a94f226dea29c3ae3d"}, - {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:95b9d5e72481d3780ba3442eac863eae92ae43a5f3adb5b4d0a1de89d42bb250"}, - {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fcf5cd9c4b655ad666ca332b9a081112cd7a58a8b5a6ca7a3104bc950f2038"}, - {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b5155ff768083cb1d62f3e143b49a8a3432e6789a3abee8acd005c3c7af1c74"}, - {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:553ef617b6836fc7e4df130bb851e32fe357ce36336d897fd6646d6058d980af"}, - {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89ed9eb7d616ef5714e5590e6cf7f23b02d0d539767d33561e3675d6f9e3857"}, - {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:75f7e9488238e920ab6204399ded280dc4c307d034f3924cd7f90a38b1829563"}, - {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ef26c9e94a8c04a1b2924149a9cb081836913818e55681722d7f29af88fe7b38"}, - {file = "pydantic_core-2.18.2-cp310-none-win32.whl", hash = "sha256:182245ff6b0039e82b6bb585ed55a64d7c81c560715d1bad0cbad6dfa07b4027"}, - {file = "pydantic_core-2.18.2-cp310-none-win_amd64.whl", hash = "sha256:e23ec367a948b6d812301afc1b13f8094ab7b2c280af66ef450efc357d2ae543"}, - {file = "pydantic_core-2.18.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:219da3f096d50a157f33645a1cf31c0ad1fe829a92181dd1311022f986e5fbe3"}, - {file = "pydantic_core-2.18.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cc1cfd88a64e012b74e94cd00bbe0f9c6df57049c97f02bb07d39e9c852e19a4"}, - {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b7133a6e6aeb8df37d6f413f7705a37ab4031597f64ab56384c94d98fa0e90"}, - {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:224c421235f6102e8737032483f43c1a8cfb1d2f45740c44166219599358c2cd"}, - {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b14d82cdb934e99dda6d9d60dc84a24379820176cc4a0d123f88df319ae9c150"}, - {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2728b01246a3bba6de144f9e3115b532ee44bd6cf39795194fb75491824a1413"}, - {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:470b94480bb5ee929f5acba6995251ada5e059a5ef3e0dfc63cca287283ebfa6"}, - {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:997abc4df705d1295a42f95b4eec4950a37ad8ae46d913caeee117b6b198811c"}, - {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:75250dbc5290e3f1a0f4618db35e51a165186f9034eff158f3d490b3fed9f8a0"}, - {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4456f2dca97c425231d7315737d45239b2b51a50dc2b6f0c2bb181fce6207664"}, - {file = "pydantic_core-2.18.2-cp311-none-win32.whl", hash = "sha256:269322dcc3d8bdb69f054681edff86276b2ff972447863cf34c8b860f5188e2e"}, - {file = "pydantic_core-2.18.2-cp311-none-win_amd64.whl", hash = "sha256:800d60565aec896f25bc3cfa56d2277d52d5182af08162f7954f938c06dc4ee3"}, - {file = "pydantic_core-2.18.2-cp311-none-win_arm64.whl", hash = "sha256:1404c69d6a676245199767ba4f633cce5f4ad4181f9d0ccb0577e1f66cf4c46d"}, - {file = "pydantic_core-2.18.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:fb2bd7be70c0fe4dfd32c951bc813d9fe6ebcbfdd15a07527796c8204bd36242"}, - {file = "pydantic_core-2.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6132dd3bd52838acddca05a72aafb6eab6536aa145e923bb50f45e78b7251043"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d904828195733c183d20a54230c0df0eb46ec746ea1a666730787353e87182"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9bd70772c720142be1020eac55f8143a34ec9f82d75a8e7a07852023e46617f"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b8ed04b3582771764538f7ee7001b02e1170223cf9b75dff0bc698fadb00cf3"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6dac87ddb34aaec85f873d737e9d06a3555a1cc1a8e0c44b7f8d5daeb89d86f"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca4ae5a27ad7a4ee5170aebce1574b375de390bc01284f87b18d43a3984df72"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:886eec03591b7cf058467a70a87733b35f44707bd86cf64a615584fd72488b7c"}, - {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca7b0c1f1c983e064caa85f3792dd2fe3526b3505378874afa84baf662e12241"}, - {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b4356d3538c3649337df4074e81b85f0616b79731fe22dd11b99499b2ebbdf3"}, - {file = "pydantic_core-2.18.2-cp312-none-win32.whl", hash = "sha256:8b172601454f2d7701121bbec3425dd71efcb787a027edf49724c9cefc14c038"}, - {file = "pydantic_core-2.18.2-cp312-none-win_amd64.whl", hash = "sha256:b1bd7e47b1558ea872bd16c8502c414f9e90dcf12f1395129d7bb42a09a95438"}, - {file = "pydantic_core-2.18.2-cp312-none-win_arm64.whl", hash = "sha256:98758d627ff397e752bc339272c14c98199c613f922d4a384ddc07526c86a2ec"}, - {file = "pydantic_core-2.18.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9fdad8e35f278b2c3eb77cbdc5c0a49dada440657bf738d6905ce106dc1de439"}, - {file = "pydantic_core-2.18.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d90c3265ae107f91a4f279f4d6f6f1d4907ac76c6868b27dc7fb33688cfb347"}, - {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:390193c770399861d8df9670fb0d1874f330c79caaca4642332df7c682bf6b91"}, - {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:82d5d4d78e4448683cb467897fe24e2b74bb7b973a541ea1dcfec1d3cbce39fb"}, - {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4774f3184d2ef3e14e8693194f661dea5a4d6ca4e3dc8e39786d33a94865cefd"}, - {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4d938ec0adf5167cb335acb25a4ee69a8107e4984f8fbd2e897021d9e4ca21b"}, - {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e8b1be28239fc64a88a8189d1df7fad8be8c1ae47fcc33e43d4be15f99cc70"}, - {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:868649da93e5a3d5eacc2b5b3b9235c98ccdbfd443832f31e075f54419e1b96b"}, - {file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:78363590ef93d5d226ba21a90a03ea89a20738ee5b7da83d771d283fd8a56761"}, - {file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:852e966fbd035a6468fc0a3496589b45e2208ec7ca95c26470a54daed82a0788"}, - {file = "pydantic_core-2.18.2-cp38-none-win32.whl", hash = "sha256:6a46e22a707e7ad4484ac9ee9f290f9d501df45954184e23fc29408dfad61350"}, - {file = "pydantic_core-2.18.2-cp38-none-win_amd64.whl", hash = "sha256:d91cb5ea8b11607cc757675051f61b3d93f15eca3cefb3e6c704a5d6e8440f4e"}, - {file = "pydantic_core-2.18.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ae0a8a797a5e56c053610fa7be147993fe50960fa43609ff2a9552b0e07013e8"}, - {file = "pydantic_core-2.18.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:042473b6280246b1dbf530559246f6842b56119c2926d1e52b631bdc46075f2a"}, - {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a388a77e629b9ec814c1b1e6b3b595fe521d2cdc625fcca26fbc2d44c816804"}, - {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25add29b8f3b233ae90ccef2d902d0ae0432eb0d45370fe315d1a5cf231004b"}, - {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f459a5ce8434614dfd39bbebf1041952ae01da6bed9855008cb33b875cb024c0"}, - {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eff2de745698eb46eeb51193a9f41d67d834d50e424aef27df2fcdee1b153845"}, - {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8309f67285bdfe65c372ea3722b7a5642680f3dba538566340a9d36e920b5f0"}, - {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f93a8a2e3938ff656a7c1bc57193b1319960ac015b6e87d76c76bf14fe0244b4"}, - {file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:22057013c8c1e272eb8d0eebc796701167d8377441ec894a8fed1af64a0bf399"}, - {file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfeecd1ac6cc1fb2692c3d5110781c965aabd4ec5d32799773ca7b1456ac636b"}, - {file = "pydantic_core-2.18.2-cp39-none-win32.whl", hash = "sha256:0d69b4c2f6bb3e130dba60d34c0845ba31b69babdd3f78f7c0c8fae5021a253e"}, - {file = "pydantic_core-2.18.2-cp39-none-win_amd64.whl", hash = "sha256:d9319e499827271b09b4e411905b24a426b8fb69464dfa1696258f53a3334641"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a1874c6dd4113308bd0eb568418e6114b252afe44319ead2b4081e9b9521fe75"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:ccdd111c03bfd3666bd2472b674c6899550e09e9f298954cfc896ab92b5b0e6d"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e18609ceaa6eed63753037fc06ebb16041d17d28199ae5aba0052c51449650a9"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5c584d357c4e2baf0ff7baf44f4994be121e16a2c88918a5817331fc7599d7"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43f0f463cf89ace478de71a318b1b4f05ebc456a9b9300d027b4b57c1a2064fb"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e1b395e58b10b73b07b7cf740d728dd4ff9365ac46c18751bf8b3d8cca8f625a"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0098300eebb1c837271d3d1a2cd2911e7c11b396eac9661655ee524a7f10587b"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:36789b70d613fbac0a25bb07ab3d9dba4d2e38af609c020cf4d888d165ee0bf3"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f9a801e7c8f1ef8718da265bba008fa121243dfe37c1cea17840b0944dfd72c"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3a6515ebc6e69d85502b4951d89131ca4e036078ea35533bb76327f8424531ce"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20aca1e2298c56ececfd8ed159ae4dde2df0781988c97ef77d5c16ff4bd5b400"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:223ee893d77a310a0391dca6df00f70bbc2f36a71a895cecd9a0e762dc37b349"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2334ce8c673ee93a1d6a65bd90327588387ba073c17e61bf19b4fd97d688d63c"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cbca948f2d14b09d20268cda7b0367723d79063f26c4ffc523af9042cad95592"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b3ef08e20ec49e02d5c6717a91bb5af9b20f1805583cb0adfe9ba2c6b505b5ae"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6fdc8627910eed0c01aed6a390a252fe3ea6d472ee70fdde56273f198938374"}, - {file = "pydantic_core-2.18.2.tar.gz", hash = "sha256:2e29d20810dfc3043ee13ac7d9e25105799817683348823f305ab3f349b9386e"}, + {file = "pydantic_core-2.18.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:f76d0ad001edd426b92233d45c746fd08f467d56100fd8f30e9ace4b005266e4"}, + {file = "pydantic_core-2.18.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:59ff3e89f4eaf14050c8022011862df275b552caef8082e37b542b066ce1ff26"}, + {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a55b5b16c839df1070bc113c1f7f94a0af4433fcfa1b41799ce7606e5c79ce0a"}, + {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4d0dcc59664fcb8974b356fe0a18a672d6d7cf9f54746c05f43275fc48636851"}, + {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8951eee36c57cd128f779e641e21eb40bc5073eb28b2d23f33eb0ef14ffb3f5d"}, + {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4701b19f7e3a06ea655513f7938de6f108123bf7c86bbebb1196eb9bd35cf724"}, + {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e00a3f196329e08e43d99b79b286d60ce46bed10f2280d25a1718399457e06be"}, + {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:97736815b9cc893b2b7f663628e63f436018b75f44854c8027040e05230eeddb"}, + {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6891a2ae0e8692679c07728819b6e2b822fb30ca7445f67bbf6509b25a96332c"}, + {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bc4ff9805858bd54d1a20efff925ccd89c9d2e7cf4986144b30802bf78091c3e"}, + {file = "pydantic_core-2.18.4-cp310-none-win32.whl", hash = "sha256:1b4de2e51bbcb61fdebd0ab86ef28062704f62c82bbf4addc4e37fa4b00b7cbc"}, + {file = "pydantic_core-2.18.4-cp310-none-win_amd64.whl", hash = "sha256:6a750aec7bf431517a9fd78cb93c97b9b0c496090fee84a47a0d23668976b4b0"}, + {file = "pydantic_core-2.18.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:942ba11e7dfb66dc70f9ae66b33452f51ac7bb90676da39a7345e99ffb55402d"}, + {file = "pydantic_core-2.18.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b2ebef0e0b4454320274f5e83a41844c63438fdc874ea40a8b5b4ecb7693f1c4"}, + {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a642295cd0c8df1b86fc3dced1d067874c353a188dc8e0f744626d49e9aa51c4"}, + {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f09baa656c904807e832cf9cce799c6460c450c4ad80803517032da0cd062e2"}, + {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98906207f29bc2c459ff64fa007afd10a8c8ac080f7e4d5beff4c97086a3dabd"}, + {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19894b95aacfa98e7cb093cd7881a0c76f55731efad31073db4521e2b6ff5b7d"}, + {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fbbdc827fe5e42e4d196c746b890b3d72876bdbf160b0eafe9f0334525119c8"}, + {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f85d05aa0918283cf29a30b547b4df2fbb56b45b135f9e35b6807cb28bc47951"}, + {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e85637bc8fe81ddb73fda9e56bab24560bdddfa98aa64f87aaa4e4b6730c23d2"}, + {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2f5966897e5461f818e136b8451d0551a2e77259eb0f73a837027b47dc95dab9"}, + {file = "pydantic_core-2.18.4-cp311-none-win32.whl", hash = "sha256:44c7486a4228413c317952e9d89598bcdfb06399735e49e0f8df643e1ccd0558"}, + {file = "pydantic_core-2.18.4-cp311-none-win_amd64.whl", hash = "sha256:8a7164fe2005d03c64fd3b85649891cd4953a8de53107940bf272500ba8a788b"}, + {file = "pydantic_core-2.18.4-cp311-none-win_arm64.whl", hash = "sha256:4e99bc050fe65c450344421017f98298a97cefc18c53bb2f7b3531eb39bc7805"}, + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"}, + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"}, + {file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"}, + {file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"}, + {file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"}, + {file = "pydantic_core-2.18.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:823be1deb01793da05ecb0484d6c9e20baebb39bd42b5d72636ae9cf8350dbd2"}, + {file = "pydantic_core-2.18.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ebef0dd9bf9b812bf75bda96743f2a6c5734a02092ae7f721c048d156d5fabae"}, + {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae1d6df168efb88d7d522664693607b80b4080be6750c913eefb77e34c12c71a"}, + {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9899c94762343f2cc2fc64c13e7cae4c3cc65cdfc87dd810a31654c9b7358cc"}, + {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99457f184ad90235cfe8461c4d70ab7dd2680e28821c29eca00252ba90308c78"}, + {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18f469a3d2a2fdafe99296a87e8a4c37748b5080a26b806a707f25a902c040a8"}, + {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7cdf28938ac6b8b49ae5e92f2735056a7ba99c9b110a474473fd71185c1af5d"}, + {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:938cb21650855054dc54dfd9120a851c974f95450f00683399006aa6e8abb057"}, + {file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:44cd83ab6a51da80fb5adbd9560e26018e2ac7826f9626bc06ca3dc074cd198b"}, + {file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:972658f4a72d02b8abfa2581d92d59f59897d2e9f7e708fdabe922f9087773af"}, + {file = "pydantic_core-2.18.4-cp38-none-win32.whl", hash = "sha256:1d886dc848e60cb7666f771e406acae54ab279b9f1e4143babc9c2258213daa2"}, + {file = "pydantic_core-2.18.4-cp38-none-win_amd64.whl", hash = "sha256:bb4462bd43c2460774914b8525f79b00f8f407c945d50881568f294c1d9b4443"}, + {file = "pydantic_core-2.18.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:44a688331d4a4e2129140a8118479443bd6f1905231138971372fcde37e43528"}, + {file = "pydantic_core-2.18.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a2fdd81edd64342c85ac7cf2753ccae0b79bf2dfa063785503cb85a7d3593223"}, + {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86110d7e1907ab36691f80b33eb2da87d780f4739ae773e5fc83fb272f88825f"}, + {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46387e38bd641b3ee5ce247563b60c5ca098da9c56c75c157a05eaa0933ed154"}, + {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:123c3cec203e3f5ac7b000bd82235f1a3eced8665b63d18be751f115588fea30"}, + {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc1803ac5c32ec324c5261c7209e8f8ce88e83254c4e1aebdc8b0a39f9ddb443"}, + {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53db086f9f6ab2b4061958d9c276d1dbe3690e8dd727d6abf2321d6cce37fa94"}, + {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:abc267fa9837245cc28ea6929f19fa335f3dc330a35d2e45509b6566dc18be23"}, + {file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a0d829524aaefdebccb869eed855e2d04c21d2d7479b6cada7ace5448416597b"}, + {file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:509daade3b8649f80d4e5ff21aa5673e4ebe58590b25fe42fac5f0f52c6f034a"}, + {file = "pydantic_core-2.18.4-cp39-none-win32.whl", hash = "sha256:ca26a1e73c48cfc54c4a76ff78df3727b9d9f4ccc8dbee4ae3f73306a591676d"}, + {file = "pydantic_core-2.18.4-cp39-none-win_amd64.whl", hash = "sha256:c67598100338d5d985db1b3d21f3619ef392e185e71b8d52bceacc4a7771ea7e"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"}, + {file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"}, ] [package.dependencies] @@ -2406,13 +2425,13 @@ files = [ [[package]] name = "requests" -version = "2.32.2" +version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, - {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -2779,17 +2798,17 @@ files = [ [[package]] name = "sympy" -version = "1.12" +version = "1.12.1" description = "Computer algebra system (CAS) in Python" optional = true python-versions = ">=3.8" files = [ - {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, - {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, + {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, + {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, ] [package.dependencies] -mpmath = ">=0.19" +mpmath = ">=1.1.0,<1.4.0" [[package]] name = "tbb" @@ -3019,13 +3038,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.41.1" +version = "4.41.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.41.1-py3-none-any.whl", hash = "sha256:f0680e0b1a01067eccd11f62f0522409422c7d6f91d532fe0f50b136a406129d"}, - {file = "transformers-4.41.1.tar.gz", hash = "sha256:fa859e4c66f0896633a3bf534e0d9a29a9a88478a49f94c5d8270537dc61cc42"}, + {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"}, + {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"}, ] [package.dependencies] @@ -3128,13 +3147,13 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6. [[package]] name = "typing-extensions" -version = "4.12.0" +version = "4.12.1" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, - {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, + {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, + {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, ] [[package]] @@ -3478,6 +3497,21 @@ files = [ idna = ">=2.0" multidict = ">=4.0" +[[package]] +name = "zipp" +version = "3.19.1" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, + {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, +] + +[package.extras] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] + [extras] accelerate = ["accelerate"] bnb = ["bitsandbytes"] @@ -3489,4 +3523,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "06e67944a2b1cf9884a31e771d0e9d89877e9b3c91894982cb67d104cb834758" +content-hash = "f62a7a74e1e1bcb3b7cb4f7da2b538065830748062a2b57fdbb4c76eae5abddc" diff --git a/server/pyproject.toml b/server/pyproject.toml index 6a3a0205..4cfbc1a2 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -9,7 +9,7 @@ text-generation-server = 'text_generation_server.cli:app' [tool.poetry.dependencies] python = ">=3.9,<3.13" -protobuf = "^4.21.7" +protobuf = "^4.25.3" grpcio = "^1.51.1" grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" @@ -19,9 +19,9 @@ accelerate = { version = "^0.29.1", optional = true } bitsandbytes = { version = "^0.43.0", optional = true } safetensors = "^0.4" loguru = "^0.6.0" -opentelemetry-api = "^1.15.0" -opentelemetry-exporter-otlp = "^1.15.0" -opentelemetry-instrumentation-grpc = "^0.36b0" +opentelemetry-api = "^1.25.0" +opentelemetry-exporter-otlp = "^1.25.0" +opentelemetry-instrumentation-grpc = "^0.46b0" hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "^0.19.1" @@ -34,7 +34,7 @@ peft = { version = "^0.10", optional = true } torch = { version = "^2.3.0", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" -outlines= { version = "^0.0.36", optional = true } +outlines= { version = "^0.0.34", optional = true } prometheus-client = "^0.20.0" py-cpuinfo = "^9.0.0" From 353a9669bab848c1d6a5e79173cbbe371103a91b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 4 Jun 2024 23:34:03 +0200 Subject: [PATCH 027/340] Hotfixing `make install`. (#2008) # What does this PR do? Fixes initial and subsequent installs (protection for folder creation should only be for git commit, checking out correct commit should be on both. Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/Makefile-flash-att | 10 ++++------ server/Makefile-flash-att-v2 | 11 ++++------- server/Makefile-vllm | 25 +++++++++++-------------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index 5570863b..29e75bc4 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -3,12 +3,10 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec build-flash-attention: if [ ! -d 'flash-attention' ]; then \ pip install -U packaging ninja --no-cache-dir && \ - git clone https://github.com/HazyResearch/flash-attention.git && \ - cd flash-attention && git fetch && git checkout $(flash_att_commit) && \ - MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build; \ + git clone https://github.com/HazyResearch/flash-attention.git; \ fi + cd flash-attention && git fetch && git checkout $(flash_att_commit) && \ + MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build install-flash-attention: build-flash-attention - if [ ! -d 'flash-attention' ]; then \ - cd flash-attntion && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install; \ - fi + cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index b67803fe..ba90a74d 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -5,9 +5,8 @@ build-flash-attention-v2-cuda: pip install -U packaging wheel pip install flash-attn==$(flash_att_v2_commit_cuda) -install-flash-attention-v2-cuda: - pip install -U packaging wheel - pip install flash-attn==$(flash_att_v2_commit_cuda) +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + echo "Flash v2 installed" build-flash-attention-v2-rocm: if [ ! -d 'flash-attention-v2' ]; then \ @@ -18,7 +17,5 @@ build-flash-attention-v2-rocm: fi install-flash-attention-v2-rocm: build-flash-attention-v2-rocm - if [ ! -d 'flash-attention-v2' ]; then \ - cd flash-attention-v2 && \ - GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install; \ - fi + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install diff --git a/server/Makefile-vllm b/server/Makefile-vllm index de3b4611..ded2f5d2 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,26 +1,23 @@ +commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa +commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ - git clone https://github.com/Narsil/vllm.git vllm &&\ - cd vllm && \ - git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa &&\ - python setup.py build; \ + git clone https://github.com/Narsil/vllm.git vllm; \ fi + cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build + install-vllm-cuda: build-vllm-cuda - if [ ! -d 'vllm' ]; then \ - cd vllm && pip install -e .; \ - fi + cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . build-vllm-rocm: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ - git clone https://github.com/fxmarty/rocm-vllm.git vllm && \ - cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 && \ - PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + git clone https://github.com/fxmarty/rocm-vllm.git vllm; \ fi + cd vllm && git fetch && git checkout $(commit_rocm) && \ + PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build install-vllm-rocm: build-vllm-rocm - if [ ! -d 'vllm' ]; then \ - cd vllm && \ - PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .; \ - fi + cd vllm && git fetch && git checkout $(commit_rocm) && \ + PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e . From cdd120ac022f17819be2c440d5efa84d8dafe7d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 5 Jun 2024 10:45:47 +0200 Subject: [PATCH 028/340] Do not initialize scratch space when there are no ExLlamaV2 layers (#2015) # What does this PR do? Do not attempt to allocate ExLlamaV2 scratch buffers when there are no ExLlama2 layers. Avoids a crash in warmup for models that cannot use exllama when ExLlamaV2 is installed. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../layers/gptq/exllamav2.py | 159 ++++++++++-------- 1 file changed, 89 insertions(+), 70 deletions(-) diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 321ced97..4d45822b 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -1,10 +1,15 @@ # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 +from dataclasses import dataclass +from typing import Optional import torch import torch.nn as nn from loguru import logger +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight + try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: @@ -15,6 +20,15 @@ except ImportError: none_tensor = torch.empty((1, 1), device="meta") +@dataclass +class _ExtraTensors: + """Additional generated quantizer tensors.""" + + q_group_map: Optional[torch.Tensor] = None + q_invperm: Optional[torch.Tensor] = None + q_perm: Optional[torch.Tensor] = None + + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): """Matrix multiplication, returns x @ q4""" output_shape = x.shape[:-1] + (q4_width,) @@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): return output.view(output_shape) -# Group map needed for irregular group sizes - - -def make_group_map(q_groups, num_qrows): - +def make_group_map(q_groups: torch.Tensor, num_qrows: int): gr = q_groups.tolist() group_map = [] num_groups = len(gr) // 2 @@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows): # Create Q matrix -def ext_make_q_matrix(w: dict, temp_dq, key: str = None): +def ext_make_q_matrix( + w: Exl2Weight | GPTQWeight, + extra: _ExtraTensors, + temp_dq, + key: Optional[str] = None, +): """ Create Q matrix """ # EXL2 - # won't work as the moment because the tensors are not the same. - if "q_weight" in w: - w["q_scale_max"] /= 256 - w["q_perm"] = w["q_perm"].short() - w["q_invperm"] = w["q_invperm"].short() - - if "q_group_map" not in w: - w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) + if isinstance(w, Exl2Weight): + extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) + extra.q_perm = torch.argsort(w.q_invperm).short() return make_q_matrix( - w["q_weight"], - w["q_perm"], - w["q_invperm"], - w["q_scale"], - w["q_scale_max"], - w["q_groups"], - w["q_group_map"], + w.q_weight, + extra.q_perm, + w.q_invperm, + w.q_scale, + w.q_scale_max, + w.q_groups, + extra.q_group_map, none_tensor, none_tensor, none_tensor, temp_dq, ) # GPTQ - elif "qweight" in w: - if w["scales"].dtype == torch.float: - w["scales"] = w["scales"].half() + elif isinstance(w, GPTQWeight): + if w.scales.dtype == torch.float: + w.scales = w.scales.half() # GPTQ with g_idx (act_order) - if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): - w["q_perm"] = torch.empty( - (w["qweight"].shape[0] * 8,), + if w.g_idx is not None and not (w.g_idx == 0).all().item(): + extra.q_perm = torch.empty( + (w.qweight.shape[0] * 8,), dtype=torch.short, - device=w["qweight"].device, + device=w.qweight.device, ) - w["q_invperm"] = torch.empty_like(w["q_perm"]) + extra.q_invperm = torch.empty_like(extra.q_perm) # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. return make_q_matrix( - w["qweight"], - w["q_perm"], - w["q_invperm"], + w.qweight, + extra.q_perm, + extra.q_invperm, none_tensor, none_tensor, none_tensor, none_tensor, - w["qzeros"], - w["scales"], - w["g_idx"].cpu(), + w.qzeros, + w.scales, + w.g_idx.cpu(), temp_dq, ) # GPTQ without g_idx else: return make_q_matrix( - w["qweight"], + w.qweight, none_tensor, none_tensor, none_tensor, none_tensor, none_tensor, none_tensor, - w["qzeros"], - w["scales"], + w.qzeros, + w.scales, none_tensor, temp_dq, ) @@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): DEVICE = None -FIXED_BYTES = 0 LAYERS = [] @@ -134,8 +143,19 @@ def set_device(device): def create_exllama_buffers(max_total_tokens: int): - global FIXED_BYTES, LAYERS, DEVICE - temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) + global LAYERS, DEVICE + + # No need to initialize scratch space if there are no layers + # that use ExLLamav2. + if len(LAYERS) == 0: + return + + # Find the size of the scratch space. + scratch_bytes = max( + layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1) + for layer in LAYERS + ) + temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) for layer in LAYERS: layer.post_init(temp_dq) @@ -146,49 +166,48 @@ class QuantLinear(nn.Module): """Linear layer implementation with per-group 4-bit quantization of the weights""" - # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + def __init__( + self, + weight: Exl2Weight | GPTQWeight, + bias: torch.Tensor, + ): super().__init__() - if bits != 4: - raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." - ) + self.q_handle = None - self.q_tensors = None - self.bits = bits - self.maxq = 2**self.bits - 1 - self.infeatures = qweight.shape[0] // self.bits * 32 - self.outfeatures = qweight.shape[1] + self.q_tensors = weight + self.extra_tensors = _ExtraTensors() + + if isinstance(weight, Exl2Weight): + self.infeatures = weight.q_invperm.shape[0] + self.outfeatures = weight.q_weight.shape[1] + elif isinstance(weight, GPTQWeight): + if weight.bits != 4: + raise ValueError( + f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." + ) + + self.infeatures = weight.qweight.shape[0] // weight.bits * 32 + self.outfeatures = weight.qweight.shape[1] + self.padding = -self.outfeatures % 32 self.outfeatures = self.outfeatures + self.padding - self.device = qweight.device - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.g_idx = g_idx + self.device = weight.device self.bias = bias if bias is not None else None - self.group_size = groupsize - global FIXED_BYTES, LAYERS - FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) + global LAYERS LAYERS.append(self) def post_init(self, temp_dq): - assert self.qweight.device.type == "cuda" - assert self.qweight.device.index is not None - self.q_tensors = { - "qweight": self.qweight, - "qzeros": self.qzeros, - "scales": self.scales, - "g_idx": self.g_idx, - } + device = self.q_tensors.device + assert device.type == "cuda" + assert device.index is not None temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, # and `Memory access fault by GPU node-2` will EAT you. self.temp_dq = temp_dq - self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) def forward(self, x, force_cuda=False): output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) @@ -203,7 +222,7 @@ class QuantLinear(nn.Module): def temp_fwd_size(self, max_input_len, max_batch_size): return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16): + def scratch_space_fixed(self, max_input_len, max_batch_size): return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) From 20df9234a9f736de45f8d6992f96619b62dded8b Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 5 Jun 2024 12:18:38 +0200 Subject: [PATCH 029/340] feat: move allocation logic to rust (#1835) Close #2007 --- Cargo.toml | 7 +- Dockerfile_amd | 10 +- Dockerfile_intel | 10 +- benchmark/src/generation.rs | 3 + proto/v3/generate.proto | 6 + router/client/src/v3/client.rs | 6 +- router/client/src/v3/sharded_client.rs | 4 + router/src/infer/v3/block_allocator.rs | 136 +++++ router/src/infer/v3/mod.rs | 1 + router/src/infer/v3/queue.rs | 218 +++++--- router/src/infer/v3/scheduler.rs | 9 +- .../models/cache_manager.py | 140 ----- .../custom_modeling/flash_cohere_modeling.py | 1 + .../custom_modeling/flash_dbrx_modeling.py | 1 + .../custom_modeling/flash_gemma_modeling.py | 1 + .../custom_modeling/flash_neox_modeling.py | 1 + .../custom_modeling/flash_phi_modeling.py | 1 + .../custom_modeling/flash_rw_modeling.py | 1 + .../flash_santacoder_modeling.py | 1 + .../models/flash_causal_lm.py | 268 ++++++---- .../models/flash_mistral.py | 497 +----------------- .../models/flash_qwen2.py | 5 +- .../models/flash_starcoder2.py | 5 +- .../models/vlm_causal_lm.py | 12 +- 24 files changed, 499 insertions(+), 845 deletions(-) create mode 100644 router/src/infer/v3/block_allocator.rs delete mode 100644 server/text_generation_server/models/cache_manager.py diff --git a/Cargo.toml b/Cargo.toml index 88497390..8a1fce05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,12 @@ incremental = true inherits = "release" debug = 1 incremental = true +panic = "abort" + +[profile.release-opt] +inherits = "release" +debug = 0 +incremental = false lto = "fat" opt-level = 3 codegen-units = 1 -panic = "abort" diff --git a/Dockerfile_amd b/Dockerfile_amd index 92dd0ea8..b0d181ea 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ rm -f $PROTOC_ZIP COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml @@ -33,7 +33,7 @@ COPY proto proto COPY benchmark benchmark COPY router router COPY launcher launcher -RUN cargo build --release +RUN cargo build --profile release-opt # Text Generation Inference base image for RoCm FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base @@ -193,11 +193,11 @@ RUN cd server && \ pip install ".[accelerate, peft, outlines]" --no-cache-dir # Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher # AWS Sagemaker compatible image FROM base as sagemaker diff --git a/Dockerfile_intel b/Dockerfile_intel index 9c9b5c16..0a700003 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ rm -f $PROTOC_ZIP COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml @@ -32,7 +32,7 @@ COPY proto proto COPY benchmark benchmark COPY router router COPY launcher launcher -RUN cargo build --release +RUN cargo build --profile release-opt # Text Generation Inference base image for Intel @@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp ENV CCL_ZE_IPC_EXCHANGE=sockets # Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher # Final image FROM base diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 27b74249..b82d23ba 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -155,6 +155,8 @@ async fn prefill( ignore_eos_token: true, // Will not stop even if a eos token is generated }), top_n_tokens: top_n_tokens.unwrap_or(0), + blocks: vec![], + slots: vec![], }) .collect(); @@ -163,6 +165,7 @@ async fn prefill( requests, size: batch_size, max_tokens: batch_size * (sequence_length + decode_length), + max_blocks: 0, }; // Run prefill diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index ca2908c9..01cc43fd 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -130,6 +130,10 @@ message Request { bool prefill_logprobs = 6; /// Return most likely n tokens uint32 top_n_tokens = 7; + /// Paged attention blocks + repeated uint32 blocks = 9; + /// Paged attention slots + repeated uint32 slots = 10; } message Batch { @@ -141,6 +145,8 @@ message Batch { uint32 size = 3; /// Maximum number of tokens this batch will grow to uint32 max_tokens = 4; + /// Maximum number of Paged Attention blocks + uint32 max_blocks = 5; } message CachedBatch { diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 1f3a89a0..9a3892fb 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -153,6 +153,9 @@ impl Client { }), // We truncate the input on the server side to be sure that it has the correct size truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -187,7 +190,8 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: 0, + max_tokens: max_input_length, + max_blocks: 0, }; let request = tonic::Request::new(WarmupRequest { diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 9b4f74d8..94002f55 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -241,12 +241,16 @@ impl Health for ShardedClient { ignore_eos_token: false, }), top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), }; let batch = Batch { id: u64::MAX, requests: vec![liveness_request], size: 1, max_tokens: 2, + max_blocks: 1, }; self.clone().prefill(batch).await?; Ok(()) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs new file mode 100644 index 00000000..7467fd85 --- /dev/null +++ b/router/src/infer/v3/block_allocator.rs @@ -0,0 +1,136 @@ +use std::cmp::min; +use tokio::sync::{mpsc, oneshot}; + +#[derive(Debug, Clone)] +pub(crate) struct BlockAllocation { + pub blocks: Vec, + pub slots: Vec, + block_allocator: BlockAllocator, +} + +impl Drop for BlockAllocation { + fn drop(&mut self) { + self.block_allocator.free(self.blocks.clone()) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct BlockAllocator { + /// Channel to communicate with the background task + block_allocator: mpsc::UnboundedSender, +} + +impl BlockAllocator { + pub(crate) fn new( + max_batch_total_tokens: u32, + block_size: u32, + window_size: Option, + ) -> Self { + // Create channel + let (sender, receiver) = mpsc::unbounded_channel(); + + // Launch background queue task + tokio::spawn(block_allocator_task( + max_batch_total_tokens / block_size, + block_size, + window_size, + receiver, + )); + + Self { + block_allocator: sender, + } + } + + pub(crate) async fn allocate(&self, tokens: u32) -> Option { + let (response_sender, response_receiver) = oneshot::channel(); + self.block_allocator + .send(BlockAllocatorCommand::Allocate { + tokens, + response_sender, + }) + .unwrap(); + + response_receiver + .await + .unwrap() + .map(|(blocks, slots)| BlockAllocation { + blocks, + slots, + block_allocator: self.clone(), + }) + } + + pub(crate) fn free(&self, blocks: Vec) { + self.block_allocator + .send(BlockAllocatorCommand::Free { blocks }) + .unwrap(); + } +} + +async fn block_allocator_task( + blocks: u32, + block_size: u32, + window_size: Option, + mut receiver: mpsc::UnboundedReceiver, +) { + // Block 0 is reserved for health checks + let mut free_blocks: Vec = (1..blocks).collect(); + while let Some(cmd) = receiver.recv().await { + match cmd { + BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Allocate { + tokens, + response_sender, + } => { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + block_size - 1) / block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + let allocation = if required_blocks > free_blocks.len() as u32 { + None + } else { + let blocks = + free_blocks.split_off(free_blocks.len() - required_blocks as usize); + let mut slots = Vec::with_capacity( + (required_blocks * block_size * repeats as u32) as usize, + ); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * block_size)..((block_id + 1) * block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some((blocks, slots)) + }; + response_sender.send(allocation).unwrap(); + } + } + } +} + +#[derive(Debug)] +enum BlockAllocatorCommand { + Free { + blocks: Vec, + }, + Allocate { + tokens: u32, + response_sender: oneshot::Sender, Vec)>>, + }, +} diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs index 4299baf3..f9effab8 100644 --- a/router/src/infer/v3/mod.rs +++ b/router/src/infer/v3/mod.rs @@ -1,3 +1,4 @@ +mod block_allocator; mod queue; mod scheduler; diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index b926f329..0b66142a 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -1,17 +1,20 @@ -use crate::infer::{InferError, InferStreamResponse}; +use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::infer::InferError; +use crate::infer::InferStreamResponse; use crate::validation::{ ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::min; +use std::cmp::{max, min}; use std::collections::VecDeque; use text_generation_client::v3::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; -use text_generation_client::{ChunksToString, Input}; +use text_generation_client::ChunksToString; +use text_generation_client::Input; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; -use tracing::{info_span, instrument, Span}; +use tracing::{info_span, instrument, Instrument, Span}; /// Queue entry #[derive(Debug)] @@ -28,6 +31,8 @@ pub(crate) struct Entry { pub queue_time: Instant, /// Instant when this entry was added to a batch pub batch_time: Option, + /// Block Allocation + pub block_allocation: Option, } /// Request Queue @@ -43,6 +48,7 @@ impl Queue { block_size: u32, window_size: Option, speculate: u32, + max_batch_total_tokens: u32, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -53,12 +59,14 @@ impl Queue { block_size, window_size, speculate, + max_batch_total_tokens, queue_receiver, )); Self { queue_sender } } + /// Append an entry to the queue #[instrument(skip_all)] pub(crate) fn append(&self, entry: Entry) { // Send append command to the background task managing the state @@ -103,9 +111,16 @@ async fn queue_task( block_size: u32, window_size: Option, speculate: u32, + max_batch_total_tokens: u32, mut receiver: mpsc::UnboundedReceiver, ) { - let mut state = State::new(requires_padding, block_size, window_size, speculate); + let mut state = State::new( + requires_padding, + block_size, + window_size, + speculate, + max_batch_total_tokens, + ); while let Some(cmd) = receiver.recv().await { match cmd { @@ -120,12 +135,14 @@ async fn queue_task( token_budget, response_sender, span, - } => span.in_scope(|| { - let next_batch = - state.next_batch(min_size, max_size, prefill_token_budget, token_budget); + } => { + let next_batch = state + .next_batch(min_size, max_size, prefill_token_budget, token_budget) + .instrument(span) + .await; response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size", state.entries.len() as f64); - }), + } } } } @@ -142,9 +159,6 @@ struct State { /// Id of the next batch next_batch_id: u64, - /// Whether the model is using padding - requires_padding: bool, - /// Paged Attention block size block_size: u32, @@ -153,6 +167,9 @@ struct State { /// Speculation amount speculate: u32, + + /// Paged Attention Block Allocation + block_allocator: Option, } impl State { @@ -161,15 +178,19 @@ impl State { block_size: u32, window_size: Option, speculate: u32, + max_batch_total_tokens: u32, ) -> Self { + let block_allocator = (!requires_padding) + .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, - requires_padding, block_size, window_size, speculate, + block_allocator, } } @@ -185,7 +206,7 @@ impl State { } // Get the next batch - fn next_batch( + async fn next_batch( &mut self, min_size: Option, max_size: Option, @@ -220,9 +241,10 @@ impl State { let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; + let mut max_blocks = 0; // Pop entries starting from the front of the queue - while let Some((id, mut entry)) = self.entries.pop_front() { + 'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { @@ -231,43 +253,67 @@ impl State { continue; } - if self.requires_padding { - // We pad to max input length in the Python shards - // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length - } else { - // pad to block size - prefill_tokens += ((entry.request.input_length + self.block_size - 1) - / self.block_size) - * self.block_size; - } + let block_allocation = match &self.block_allocator { + None => { + // We pad to max input length in the Python shards + // We need to take these padding tokens into the equation + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length; - if self.requires_padding { - decode_tokens += entry.request.stopping_parameters.max_new_tokens; - } else { - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + let total_tokens = prefill_tokens + decode_tokens + self.speculate; - // pad to block size - decode_tokens += - ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; - } + if prefill_tokens > prefill_token_budget || total_tokens > token_budget { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + None + } + Some(block_allocator) => { + prefill_tokens += entry.request.input_length; + let max_new_tokens = match self.window_size { + None => entry.request.stopping_parameters.max_new_tokens, + Some(window_size) => min( + window_size.saturating_sub(entry.request.input_length), + entry.request.stopping_parameters.max_new_tokens, + ), + }; + decode_tokens += max_new_tokens; - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens + self.speculate) > token_budget + { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break; + } + + let tokens = entry.request.input_length + + entry.request.stopping_parameters.max_new_tokens + + self.speculate + - 1; + + match block_allocator.allocate(tokens).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + Some(block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + Some(block_allocation) + } + } + } + }; tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry @@ -278,13 +324,23 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); + let (blocks, slots) = match &block_allocation { + None => (Vec::new(), Vec::new()), + Some(block_allocation) => ( + block_allocation.blocks.clone(), + block_allocation.slots.clone(), + ), + }; + + entry.block_allocation = block_allocation; + batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.chunks_to_string(), input_chunks: Some(Input { chunks: entry.request.inputs.clone(), }), + inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), @@ -293,6 +349,8 @@ impl State { entry.request.stopping_parameters.clone(), )), top_n_tokens: entry.request.top_n_tokens, + blocks, + slots, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -335,6 +393,7 @@ impl State { requests: batch_requests, size, max_tokens: (prefill_tokens + decode_tokens), + max_blocks, }; // Increment batch id self.next_batch_id += 1; @@ -438,13 +497,14 @@ mod tests { temp_span: None, queue_time: Instant::now(), batch_time: None, + block_allocation: None, }; (entry, receiver_tx) } - #[test] - fn test_append() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_append() { + let mut state = State::new(false, 1, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -458,23 +518,23 @@ mod tests { assert_eq!(id, 0); } - #[test] - fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_empty() { + let mut state = State::new(false, 1, None, 0, 16); - assert!(state.next_batch(None, None, 1, 1).is_none()); - assert!(state.next_batch(Some(1), None, 1, 1).is_none()); + assert!(state.next_batch(None, None, 1, 1).await.is_none()); + assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); } - #[test] - fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_min_size() { + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -490,7 +550,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), None, 2, 2).is_none()); + assert!(state.next_batch(Some(2), None, 2, 2).await.is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -498,15 +558,15 @@ mod tests { assert_eq!(id, 2); } - #[test] - fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_max_size() { + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert!(entries.get(&0).unwrap().batch_time.is_some()); @@ -518,15 +578,15 @@ mod tests { assert_eq!(state.next_batch_id, 1); } - #[test] - fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_token_budget() { + let mut state = State::new(false, 1, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -539,7 +599,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -553,14 +613,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -568,7 +628,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -601,7 +661,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -617,7 +677,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -642,7 +702,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2); + let queue = Queue::new(false, 1, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -661,7 +721,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 257d191f..ad03dd83 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -39,7 +39,13 @@ impl SchedulerV3 { speculate: u32, generation_health: Arc, ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); + let queue = Queue::new( + requires_padding, + 16, + window_size, + speculate, + max_batch_total_tokens, + ); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic @@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 { temp_span: None, queue_time: Instant::now(), batch_time: None, + block_allocation: None, }); // Notify the background task that we have a new entry in the queue that needs diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py deleted file mode 100644 index c7705fe8..00000000 --- a/server/text_generation_server/models/cache_manager.py +++ /dev/null @@ -1,140 +0,0 @@ -import math -import torch - -from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM - -BLOCK_SIZE: int = 16 -# Will be set in warmup -CACHE_MANAGER: Optional["CacheManager"] = None - - -class CacheManager: - def __init__( - self, - num_blocks: int, - num_layers: int, - num_heads: int, - head_size: int, - repeat_slots: bool, - dtype: torch.dtype, - device: torch.device, - ): - self.block_size = BLOCK_SIZE - self.num_blocks = num_blocks - self.repeat_slots = repeat_slots - - element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "xpu": - x = 1 - else: - x = self.block_size // element_size - - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, self.block_size, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, self.block_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") - self.slots = torch.arange( - 0, num_blocks * self.block_size, dtype=torch.int64 - ).view(num_blocks, self.block_size) - - def allocate( - self, - needed_blocks_slots: List[Tuple[int, int]], - blocks: int, - max_blocks: int, - device: torch.device, - ): - # Get free blocks indices by finding values in mask that are not set to 0 - free_block_indices = self.free_block_mask.nonzero() - if blocks > len(free_block_indices): - raise RuntimeError( - f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" - ) - - # Slice by the number of required blocks - block_indices = free_block_indices[:blocks] - block_indices = block_indices.flatten() - - # Padded block tables - block_tables_tensor = torch.zeros( - (len(needed_blocks_slots), max_blocks), dtype=torch.int32 - ) - - # Allocate paged attention blocks - cumulative_blocks = 0 - slots = [] - block_tables = [] - for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): - # Get allocated blocks for this sequence - allocated_blocks = block_indices[ - cumulative_blocks : cumulative_blocks + needed_blocks - ] - # Get slots for the allocated blocks - all_slots = self.slots[allocated_blocks].flatten() - - # Repeat slots in the case of context sliding window - if needed_slots > len(all_slots) and self.repeat_slots: - repeats = math.ceil(needed_slots / len(all_slots)) - all_slots = all_slots.repeat(repeats) - - allocated_slots = all_slots[:needed_slots] - - slots.append(allocated_slots) - block_tables.append(allocated_blocks.tolist()) - block_tables_tensor[i, :needed_blocks] = allocated_blocks - cumulative_blocks += needed_blocks - - block_tables = block_tables - block_tables_tensor = block_tables_tensor.to(device) - slots = torch.concat(slots).to(device) - - # Allocate the required number of blocks by setting the mask to 0 - self.free_block_mask[block_indices] = 0 - - return block_tables, block_tables_tensor, slots - - def free(self, block_indices: Optional[List[int]]): - if block_indices is not None and block_indices: - # Reset mask - self.free_block_mask[block_indices] = 1 - - -def set_cache_manager( - num_blocks: int, - num_layers: int, - num_heads: int, - head_size: int, - repeat_slots: bool, - dtype: torch.dtype, - device: torch.device, -) -> CacheManager: - global CACHE_MANAGER - if CACHE_MANAGER is not None: - del CACHE_MANAGER - torch.cuda.empty_cache() - - CACHE_MANAGER = CacheManager( - num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device - ) - return CACHE_MANAGER - - -def get_cache_manager() -> CacheManager: - global CACHE_MANAGER - if CACHE_MANAGER is None: - raise RuntimeError("cache manager was not initialized") - - return CACHE_MANAGER diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 31109bc9..764dc6e2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 7967e420..9c32490e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 89ca8b5b..339198a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_embeds = self.embed_tokens(input_ids) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 59e7bf8b..d399be2f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox( diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 53d3ea42..0a47b1cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d489c3ba..7d3c72a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 1f47550e..74eedc51 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 86d9b4c8..d8c8838c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -25,11 +25,6 @@ from text_generation_server.models.types import ( Generation, GeneratedText, ) -from text_generation_server.models.cache_manager import ( - get_cache_manager, - set_cache_manager, - BLOCK_SIZE, -) from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS import text_generation_server.models.globals as tgi_globals @@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) +BLOCK_SIZE: int = 16 + +# Will be set in init +SLIDING_WINDOW: Optional[int] = None + + +def set_sliding_window(sliding_window: int): + global SLIDING_WINDOW + SLIDING_WINDOW = sliding_window + + +def get_sliding_windows() -> int: + global SLIDING_WINDOW + return SLIDING_WINDOW + @dataclass class FlashCausalLMBatch(Batch): @@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch): # Decoder values input_ids: torch.Tensor position_ids: torch.Tensor - speculative_ids: torch.Tensor + speculative_ids: Optional[torch.Tensor] # Flash Attention values # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] # Paged Attention values @@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch): start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode slot_indices: torch.Tensor - # List of tuple of ints representing the number of blocks and slots needed by each sequence - needed_blocks_slots: Optional[List[Tuple[int, int]]] - # Set in prefill by the CacheManager # list of length b of list of length s_i // block_size - block_tables: Optional[List[List[int]]] + block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences - block_tables_tensor: Optional[torch.Tensor] + block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: Optional[torch.Tensor] + slots: torch.Tensor max_seqlen: int @@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Number of blocks in this batch - blocks: int + num_blocks: int # Maximum number of blocks max_blocks: int @@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch): id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), - max_tokens=self.blocks * BLOCK_SIZE, + max_tokens=self.num_blocks * BLOCK_SIZE, ) @classmethod @@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch): )["input_ids"] return batch_tokenized_inputs - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) - return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - @classmethod def from_tokenized( cls, @@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": + sliding_window = get_sliding_windows() position_ids = [] - speculative_ids = [] cu_seqlen_prefill = [0] - needed_blocks_slots = [] start_slots = [] slot_indices = [] + prefill_cache_indices = [] input_lengths = [] prefix_offsets = [] @@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch): cumulative_max_length = 0 prefill_out_cumulative_length = 0 - blocks = 0 + num_blocks = 0 max_seqlen = 0 max_length = 0 max_blocks = 0 + block_tables = [] + slots = [] + # Parse batch for i, (r, tokenized_input) in enumerate( zip(pb.requests, batch_tokenized_inputs) @@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch): speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length total_tokens = input_length + max_new_tokens - 1 + speculative_length - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) - blocks += needed_blocks - needed_blocks_slots.append((needed_blocks, total_tokens)) + + # blocks and slots can be empty (for example in warmup) + if not r.blocks: + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + request_blocks = [ + b for b in range(num_blocks, num_blocks + needed_blocks) + ] + request_slots = [ + s + for b in request_blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_blocks = r.blocks + request_slots = r.slots + + block_tables.append(request_blocks) + slots.extend(request_slots[:total_tokens]) + num_blocks += len(request_blocks) start_slots.append(cumulative_max_length) request_slot_indices = torch.arange( @@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch): ) slot_indices.append(request_slot_indices) + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + dtype=torch.int64, + ) + prefill_cache_indices.append(request_prefill_cache_indices) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs @@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch): cumulative_length += input_length cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, needed_blocks) + max_blocks = max(max_blocks, len(request_blocks)) max_length = max( max_length, input_length + max_new_tokens + speculative_length ) @@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch): input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) else: input_ids = all_input_ids[0] position_ids = position_ids[0] slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill, device=device, dtype=torch.int32 ) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) + prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_lengths_tensor = torch.tensor( input_lengths, dtype=torch.int32, device=device @@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens, device=device, dtype=torch.int64 ) + slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( + (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + ) + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + block_tables_tensor = block_tables_tensor.to(device) + return cls( batch_id=pb.id, requests=pb.requests, @@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, + prefill_cache_indices=prefill_cache_indices, start_slots=start_slots, slot_indices=slot_indices, - needed_blocks_slots=needed_blocks_slots, - block_tables=None, - block_tables_tensor=None, - slots=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, + num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=None, ) + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": + batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) + return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == 0: @@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] - blocks = 0 + num_blocks = 0 max_blocks = 0 # Cumulative length cumulative_max_length = 0 @@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch): ) request_block_table = self.block_tables[idx] - blocks += len(request_block_table) + num_blocks += len(request_block_table) block_tables.append(request_block_table) start_slots.append(cumulative_max_length) @@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch): max_blocks = max(max_blocks, len(request_block_table)) - block_indices_to_free = [] - # Iterate on all requests - for i, r in enumerate(self.requests): - # Filter requests that are not part of the new batch - if r.id not in requests_idx_mapping.keys(): - block_indices_to_free.extend(self.block_tables[i]) - # Free blocks - get_cache_manager().free(block_indices_to_free) - # Needed to avoid dropping blocks when the batches will go out of scope - self.block_tables = None - # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] @@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, + prefill_cache_indices=None, start_slots=start_slots, slot_indices=slot_indices, - needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, @@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, + num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, ) @@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} - blocks = 0 + num_blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 @@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch): for b in batches: total_batch_size += len(b) total_slots += len(b.slots) - blocks += b.blocks + num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) @@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch): else None ) - # Needed to avoid dropping blocks when the batches will go out of scope - for b in batches: - b.block_tables = None - del b - return cls( batch_id=batches[0].batch_id, requests=requests, @@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, + prefill_cache_indices=None, start_slots=start_slots, slot_indices=slot_indices, - needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, @@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, + num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, ) - def __del__(self): - if self.block_tables is not None and self.block_tables: - # Free blocks - get_cache_manager().free( - list(itertools.chain.from_iterable(self.block_tables)) - ) - def __len__(self): return len(self.requests) @@ -702,6 +732,7 @@ class FlashCausalLM(Model): self.head_size = head_size self.cuda_graphs = {} + self.kv_cache = [] super(FlashCausalLM, self).__init__( model=model, @@ -718,6 +749,43 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def max_past(self) -> int: + return getattr(self.model, "max_past", None) + + def init_kv_cache( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.kv_cache = [] + empty_cache() + + element_size = torch.tensor([], dtype=dtype).element_size() + if SYSTEM == "xpu": + x = 1 + else: + x = BLOCK_SIZE // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) @@ -728,12 +796,11 @@ class FlashCausalLM(Model): .repeat(bs) .reshape((bs, max_bt)) ) - kv_cache = get_cache_manager().kv_cache self.cuda_graphs[bs] = { "input_ids": input_ids, "position_ids": position_ids, - "kv_cache": kv_cache, + "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, "input_lengths": input_lengths, @@ -747,11 +814,12 @@ class FlashCausalLM(Model): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, - kv_cache=kv_cache, + kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, input_lengths=input_lengths, max_s=max_s, + prefill_cache_indices=None, lm_head_indices=None, ) torch.cuda.synchronize() @@ -761,11 +829,12 @@ class FlashCausalLM(Model): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, - kv_cache=kv_cache, + kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, input_lengths=input_lengths, max_s=max_s, + prefill_cache_indices=None, lm_head_indices=None, ) self.cuda_graphs[bs]["logits"] = logits @@ -777,17 +846,16 @@ class FlashCausalLM(Model): empty_cache() try: - cache_manager = set_cache_manager( - batch.blocks, + self.init_kv_cache( + batch.num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.sliding_window is not None, self.dtype, self.device, ) max_bt = batch.max_blocks - max_s = max_bt * get_cache_manager().block_size + max_s = max_bt * BLOCK_SIZE if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): torch.cuda.tunable.tuning_enable(False) @@ -811,19 +879,17 @@ class FlashCausalLM(Model): num_blocks = ( # Leave 5% for some wiggle room int((free_memory * 0.95) // total_cache_size) - # Add batch.blocks as we allocated it above, so it is included in the peak memory. - + cache_manager.num_blocks + # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + + batch.num_blocks ) del batch - del cache_manager - set_cache_manager( + self.init_kv_cache( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.sliding_window is not None, self.dtype, self.device, ) @@ -889,7 +955,6 @@ class FlashCausalLM(Model): input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - kv_cache = get_cache_manager().kv_cache # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) @@ -901,12 +966,13 @@ class FlashCausalLM(Model): cu_seqlen_prefill=torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ), - kv_cache=get_cache_manager().kv_cache, + kv_cache=self.kv_cache, block_tables=None, input_lengths=input_lengths, slots=slots, max_s=seqlen, lm_head_indices=None, + prefill_cache_indices=None, ) def forward( @@ -917,7 +983,7 @@ class FlashCausalLM(Model): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor @@ -956,13 +1022,19 @@ class FlashCausalLM(Model): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) if sorted_padded_bs: @@ -972,7 +1044,7 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - return self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -981,8 +1053,12 @@ class FlashCausalLM(Model): slots=slots, input_lengths=input_lengths, max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -1015,24 +1091,7 @@ class FlashCausalLM(Model): prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - if batch.needed_blocks_slots: - # Allocate blocks to this batch - block_tables, block_tables_tensor, slots = get_cache_manager().allocate( - batch.needed_blocks_slots, - batch.blocks, - batch.max_blocks, - batch.input_ids.device, - ) - batch.needed_blocks_slots = None - batch.block_tables = block_tables - batch.block_tables_tensor = block_tables_tensor - batch.slots = slots - - try: - out, speculative_logits = self.forward(batch) - except Exception as e: - del batch - raise e + out, speculative_logits = self.forward(batch) if prefill: next_token_logits = ( @@ -1327,7 +1386,6 @@ class FlashCausalLM(Model): batch.all_input_ids[i] = all_input_ids if stopped: - del batch # No need to return a batch if we know that all requests stopped forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index e6125e29..081c2e2c 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,308 +1,24 @@ -import math import torch import torch.distributed -import numpy as np - -from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig -from typing import Optional, Tuple, Type +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, Tuple -from text_generation_server.pb import generate_pb2 from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE -from text_generation_server.models.cache_manager import ( - get_cache_manager, -) +from text_generation_server.models.flash_causal_lm import set_sliding_window from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, MistralConfig, ) -from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - HeterogeneousNextTokenChooser, - StoppingCriteria, ) - -tracer = trace.get_tracer(__name__) - -# Will be set in init -SLIDING_WINDOW: Optional[int] = None -SLIDING_WINDOW_BLOCKS: Optional[int] = None from text_generation_server.utils.import_utils import SYSTEM -MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None - - -def set_sliding_window(sliding_window: int, sliding_window_blocks: int): - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - SLIDING_WINDOW = sliding_window - SLIDING_WINDOW_BLOCKS = sliding_window_blocks - - -def get_sliding_windows() -> Tuple[int, int]: - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS - - -# Adds windowing logic to FlashCausalLMBatch -@dataclass -class FlashMistralBatch(FlashCausalLMBatch): - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] = None - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) - return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - - @classmethod - def from_tokenized( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - batch_tokenized_inputs, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - sliding_window, sliding_window_blocks = get_sliding_windows() - - position_ids = [] - cu_seqlen_prefill = [0] - needed_blocks_slots = [] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - - # Cumulative length - cumulative_length = 0 - cumulative_max_length = 0 - prefill_out_cumulative_length = 0 - - blocks = 0 - max_seqlen = 0 - max_length = 0 - max_blocks = 0 - - # Parse batch - for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) - ): - # request id -> idx in list mapping - requests_idx_mapping[r.id] = i - - tokenized_input = tokenized_input[-r.truncate :] - if ( - tokenized_input[0] == tokenizer.bos_token_id - and tokenized_input[1] == tokenizer.bos_token_id - ): - tokenized_input = tokenized_input[1:] - - input_length = len(tokenized_input) - input_lengths.append(input_length) - - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) - - all_input_ids.append(tokenized_input) - - # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - - next_token_chooser_parameters.append(r.parameters) - - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - max_new_tokens = stopping_criteria.max_new_tokens - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - - # Paged attention - # Remove one as the first token des not have a past - speculative_length = get_speculate() - total_tokens = input_length + max_new_tokens - 1 + speculative_length - - # Needed blocks can not go over SLIDING_WINDOW_BLOCKS - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) - if sliding_window_blocks is not None: - needed_blocks = min(needed_blocks, sliding_window_blocks) - blocks += needed_blocks - - needed_blocks_slots.append((needed_blocks, total_tokens)) - start_slots.append(cumulative_max_length) - - request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - - # Update - cumulative_length += input_length - cumulative_max_length += total_tokens - max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, needed_blocks) - max_length = max( - max_length, input_length + max_new_tokens + speculative_length - ) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device, tokenizer - ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Padded all_input_ids_tensor - all_input_ids_tensor = np.zeros( - (len(all_input_ids), max_length), dtype=np.int64 - ) - for i, input_ids in enumerate(all_input_ids): - all_input_ids_tensor[i, : len(input_ids)] = input_ids - - # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) - - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - start_slots=start_slots, - slot_indices=slot_indices, - needed_blocks_slots=needed_blocks_slots, - block_tables=None, - block_tables_tensor=None, - slots=None, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, - max_blocks=max_blocks, - prefill_cache_indices=prefill_cache_indices, - speculative_ids=None, - ) +tracer = trace.get_tracer(__name__) class BaseFlashMistral(FlashCausalLM): @@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM): # Set context windows if getattr(config, "sliding_window", None) is not None: - set_sliding_window( - config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) - ) + set_sliding_window(config.sliding_window) else: config.sliding_window = None @@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM): model.model.head_size, ) - def max_past(self) -> int: - return self.model.max_past - - @property - def batch_type(self) -> Type[FlashMistralBatch]: - return FlashMistralBatch - - def tunableop_warmup(self, seqlen: int): - input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - kv_cache = get_cache_manager().kv_cache - - # Dummy value, some models (starcoder2) don't accept `None`. - input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - - # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ), - kv_cache=get_cache_manager().kv_cache, - block_tables=None, - input_lengths=input_lengths, - slots=slots, - max_s=seqlen, - lm_head_indices=None, - prefill_cache_indices=None, - ) - - def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) - ) - kv_cache = get_cache_manager().kv_cache - - self.cuda_graphs[bs] = { - "input_ids": input_ids, - "position_ids": position_ids, - "kv_cache": kv_cache, - "block_tables": block_tables, - "slots": slots, - "input_lengths": input_lengths, - } - graph = torch.cuda.CUDAGraph() - self.cuda_graphs[bs]["graph"] = graph - - torch.cuda.synchronize() - # Run once outside to warmup - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits - torch.cuda.synchronize() - - def forward( - self, batch: FlashMistralBatch - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Model Forward - if batch.speculative_ids is not None: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - speculative_ids = batch.speculative_ids - - B, speculative_length = speculative_ids.shape - new_length = speculative_length + 1 - new_input_ids = torch.cat( - [input_ids.unsqueeze(-1), speculative_ids], dim=1 - ).reshape(-1) - arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) - arange_int = arange.to(dtype=torch.int32) - new_position_ids = ( - position_ids.unsqueeze(-1).expand(B, new_length) + arange - ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int - ).view(-1) - - # Add Copy the block tables for all members - block_tables = ( - block_tables.unsqueeze(1) - .expand(B, new_length, -1) - .reshape(B * new_length, -1) - .contiguous() - ) - max_s = max_s + speculative_length - - input_ids = new_input_ids - position_ids = new_position_ids - else: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. - max_s = min(self.max_past(), max_s) - - bs = input_ids.shape[0] - padded_bs = bs - if bs == 3: - padded_bs = 4 - elif 3 < bs <= 8: - padded_bs = 8 - elif bs > 8: - padded_bs = (bs + 7) // 8 * 8 - - # Try to find an associated cuda graph - cuda_graph = self.cuda_graphs.get(padded_bs, None) - - if cu_seqlen_prefill is not None or cuda_graph is None: - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(-1) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None - ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits - class FlashMistral(BaseFlashMistral): def __init__( diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 59064b30..75285863 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -7,7 +7,6 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoConfig from typing import Optional -from text_generation_server.models.cache_manager import BLOCK_SIZE from text_generation_server.models.flash_mistral import ( BaseFlashMistral, set_sliding_window, @@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral): # Set context windows if config.sliding_window is not None: - set_sliding_window( - config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) - ) + set_sliding_window(config.sliding_window) torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index dc5d49be..5533c9d9 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -6,7 +6,6 @@ from typing import Optional from transformers.models.gpt2 import GPT2TokenizerFast -from text_generation_server.models.cache_manager import BLOCK_SIZE from text_generation_server.models.flash_mistral import ( BaseFlashMistral, set_sliding_window, @@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral): # Set context windows if config.sliding_window is not None: - set_sliding_window( - config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) - ) + set_sliding_window(config.sliding_window) torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f0db89b2..92d79070 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_mistral import ( BaseFlashMistral, - FlashMistralBatch, -) -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch -from text_generation_server.models.cache_manager import ( - get_cache_manager, ) tracer = trace.get_tracer(__name__) @@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image: return image -class VlmCausalLMBatch(FlashMistralBatch): +class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor @@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor From 8ee07f0eaedc2b839c192451c1137a8877b5a5d6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 5 Jun 2024 14:41:34 +0200 Subject: [PATCH 030/340] Fixing rocm. (#2021) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../layers/attention/rocm.py | 121 ++++-------------- 1 file changed, 26 insertions(+), 95 deletions(-) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 2d3601c8..535810aa 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -126,40 +126,34 @@ if ENGINE != "triton": import flash_attn_2_cuda logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") - except ImportError: - try: - import flash_attn_cuda + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: - ENGINE = "v1" - logger.info("ROCm: using Flash Attention 1") - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - - for idx in range(torch.cuda.device_count()): - name = torch.cuda.get_device_name(idx) - if "MI210" not in name and "MI250" not in name: - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - raise ImportError( - f"AMD GPU with ROCm capability {major} {minor} is not supported" - ) from e + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e -SUPPORTS_WINDOWING = ENGINE != "v1" +SUPPORTS_WINDOWING = False if ENGINE == "ck": def attention( @@ -186,17 +180,12 @@ if ENGINE == "ck": out, cu_seqlens, cu_seqlens, - None, - None, - None, max_s, max_s, 0.0, softmax_scale, False, causal, - window_size_left, - 0, False, None, ) @@ -234,62 +223,4 @@ elif ENGINE == "triton": return output else: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - return flash_attn_cuda.fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - True, - False, - 0, - None, - ) + raise RuntimeError(f"Unknown attention engine {ENGINE}") From af9d60c985f847595e2bf123ac534fdcef3a5728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 5 Jun 2024 14:49:15 +0200 Subject: [PATCH 031/340] Fix GPTQWeight import (#2020) # What does this PR do? Fix stray import. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/text_generation_server/layers/gptq/exllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/gptq/exllama.py b/server/text_generation_server/layers/gptq/exllama.py index 4875af38..f27666b7 100644 --- a/server/text_generation_server/layers/gptq/exllama.py +++ b/server/text_generation_server/layers/gptq/exllama.py @@ -1,4 +1,4 @@ -from text_generation_server.utils.weights import GPTQWeight +from text_generation_server.layers.gptq import GPTQWeight import torch from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params From 2c1ff79d3824722fa2bbcd1c4b28a9519039c4af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Marafioti?= Date: Wed, 5 Jun 2024 14:51:07 +0200 Subject: [PATCH 032/340] Update __version__ on __init__.py to 0.7.0 (#2017) There was a new release of the python client with version upped to 0.7.0 on pip and on the pyproject.toml, but it wasn't changed on the __init__.py so when one does: ```python import text_generation print(text_generation.__version__) ``` It still outputs "0.6.0" # What does this PR do? Fixes # (issue) ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- clients/python/text_generation/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index a8e67071..d7a09c9e 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.6.0" +__version__ = "0.7.0" DEPRECATION_WARNING = ( "`text_generation` clients are deprecated and will be removed in the near future. " From 346f77f8bab7247a111b6f83983484d447dd2c96 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 6 Jun 2024 10:33:01 +0200 Subject: [PATCH 033/340] Less cache misses on cargo build. --- Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Dockerfile b/Dockerfile index 6492d723..d2cde6a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,6 +23,9 @@ COPY --from=planner /usr/src/recipe.json recipe.json COPY Cargo.lock Cargo.lock RUN cargo chef cook --release --recipe-path recipe.json +ARG GIT_SHA +ARG DOCKER_LABEL + COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto From 9c75591c119580396517c257eb9376185bdc7958 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 6 Jun 2024 10:33:55 +0200 Subject: [PATCH 034/340] Revert "Less cache misses on cargo build." This reverts commit 5aec4154c2b7a49c80f256422c523410d9f96ec3. --- Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index d2cde6a4..531e02df 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,6 +13,9 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +ARG GIT_SHA +ARG DOCKER_LABEL + RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -23,9 +26,6 @@ COPY --from=planner /usr/src/recipe.json recipe.json COPY Cargo.lock Cargo.lock RUN cargo chef cook --release --recipe-path recipe.json -ARG GIT_SHA -ARG DOCKER_LABEL - COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto From 77ac0f364b0289b03600911dfc81bec5204ef8c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Wed, 5 Jun 2024 08:14:40 +0000 Subject: [PATCH 035/340] Add support for Marlin-quantized models This change adds support for Marlin-quantized models. Marlin is an FP16xINT4 matmul kernel, which provides good speedups decoding batches of 16-32 tokens. It supports quantized models with symmetric quantization, groupsize -1 or 128, and 4-bit. Tested with: - Llama 2 - Llama 3 - Phi 3 --- docs/source/basic_tutorials/launcher.md | 1 + .../test_flash_llama_marlin.json | 89 +++++ .../test_flash_llama_marlin_all_params.json | 89 +++++ .../test_flash_llama_marlin_load.json | 358 ++++++++++++++++++ .../models/test_flash_llama_marlin.py | 63 +++ launcher/src/main.rs | 5 + server/Makefile | 1 + server/Makefile-marlin | 13 + server/text_generation_server/cli.py | 1 + .../text_generation_server/layers/linear.py | 8 + .../text_generation_server/layers/marlin.py | 96 +++++ .../layers/tensor_parallel.py | 2 +- .../text_generation_server/models/__init__.py | 2 +- .../custom_modeling/flash_dbrx_modeling.py | 5 + .../custom_modeling/flash_gemma_modeling.py | 2 +- .../custom_modeling/flash_gpt2_modeling.py | 4 + .../custom_modeling/flash_mixtral_modeling.py | 2 +- .../custom_modeling/flash_phi_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 2 +- .../flash_santacoder_modeling.py | 4 + .../models/flash_gpt2.py | 2 +- .../text_generation_server/utils/weights.py | 35 ++ 22 files changed, 779 insertions(+), 7 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_load.json create mode 100644 integration-tests/models/test_flash_llama_marlin.py create mode 100644 server/Makefile-marlin create mode 100644 server/text_generation_server/layers/marlin.py diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index acab822e..9246093e 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -64,6 +64,7 @@ Options: - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels + - marlin: 4 bit quantization. Requires a specific Marlin quantized model: - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin.json new file mode 100644 index 00000000..47849a3f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -12.390625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.0625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0507812, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -2.3007812, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.0449219, + "special": false, + "text": "I" + }, + { + "id": 505, + "logprob": -1.3242188, + "special": false, + "text": " have" + }, + { + "id": 263, + "logprob": -0.2076416, + "special": false, + "text": " a" + }, + { + "id": 1243, + "logprob": -2.0273438, + "special": false, + "text": " test" + }, + { + "id": 2009, + "logprob": -0.6845703, + "special": false, + "text": " request" + }, + { + "id": 515, + "logprob": -1.1748047, + "special": false, + "text": " from" + }, + { + "id": 263, + "logprob": -1.0644531, + "special": false, + "text": " a" + }, + { + "id": 1404, + "logprob": -1.5224609, + "special": false, + "text": " user" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI have a test request from a user" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_all_params.json new file mode 100644 index 00000000..bda2393e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -12.390625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.0625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -1.2607422, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 6527, + "logprob": -0.11450195, + "special": false, + "text": " Could" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " not" + }, + { + "id": 4511, + "logprob": -0.2286377, + "special": false, + "text": " connect" + }, + { + "id": 304, + "logprob": 0.0, + "special": false, + "text": " to" + }, + { + "id": 1923, + "logprob": -1.2568359, + "special": false, + "text": " server" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.15905762, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -0.21618652, + "special": false, + "text": "I" + } + ], + "top_tokens": null + }, + "generated_text": "Test request failed: Could not connect to server\n\nI" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_load.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_load.json new file mode 100644 index 00000000..44c26efb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -12.390625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.0625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0507812, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -2.3007812, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.0449219, + "special": false, + "text": "I" + }, + { + "id": 505, + "logprob": -1.3242188, + "special": false, + "text": " have" + }, + { + "id": 263, + "logprob": -0.2076416, + "special": false, + "text": " a" + }, + { + "id": 1243, + "logprob": -2.0273438, + "special": false, + "text": " test" + }, + { + "id": 2009, + "logprob": -0.6845703, + "special": false, + "text": " request" + }, + { + "id": 515, + "logprob": -1.1748047, + "special": false, + "text": " from" + }, + { + "id": 263, + "logprob": -1.0595703, + "special": false, + "text": " a" + }, + { + "id": 1404, + "logprob": -1.5224609, + "special": false, + "text": " user" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI have a test request from a user" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -12.390625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.0625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0507812, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -2.3007812, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.0449219, + "special": false, + "text": "I" + }, + { + "id": 505, + "logprob": -1.3242188, + "special": false, + "text": " have" + }, + { + "id": 263, + "logprob": -0.2076416, + "special": false, + "text": " a" + }, + { + "id": 1243, + "logprob": -2.0273438, + "special": false, + "text": " test" + }, + { + "id": 2009, + "logprob": -0.6845703, + "special": false, + "text": " request" + }, + { + "id": 515, + "logprob": -1.1748047, + "special": false, + "text": " from" + }, + { + "id": 263, + "logprob": -1.0595703, + "special": false, + "text": " a" + }, + { + "id": 1404, + "logprob": -1.5224609, + "special": false, + "text": " user" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI have a test request from a user" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -12.390625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.0625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0507812, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -2.3007812, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.0449219, + "special": false, + "text": "I" + }, + { + "id": 505, + "logprob": -1.3242188, + "special": false, + "text": " have" + }, + { + "id": 263, + "logprob": -0.2076416, + "special": false, + "text": " a" + }, + { + "id": 1243, + "logprob": -2.0273438, + "special": false, + "text": " test" + }, + { + "id": 2009, + "logprob": -0.6845703, + "special": false, + "text": " request" + }, + { + "id": 515, + "logprob": -1.1748047, + "special": false, + "text": " from" + }, + { + "id": 263, + "logprob": -1.0595703, + "special": false, + "text": " a" + }, + { + "id": 1404, + "logprob": -1.5224609, + "special": false, + "text": " user" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI have a test request from a user" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -12.390625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.0625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0507812, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -2.3007812, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.0449219, + "special": false, + "text": "I" + }, + { + "id": 505, + "logprob": -1.3242188, + "special": false, + "text": " have" + }, + { + "id": 263, + "logprob": -0.2076416, + "special": false, + "text": " a" + }, + { + "id": 1243, + "logprob": -2.0273438, + "special": false, + "text": " test" + }, + { + "id": 2009, + "logprob": -0.6845703, + "special": false, + "text": " request" + }, + { + "id": 515, + "logprob": -1.1748047, + "special": false, + "text": " from" + }, + { + "id": 263, + "logprob": -1.0595703, + "special": false, + "text": " a" + }, + { + "id": 1404, + "logprob": -1.5224609, + "special": false, + "text": " user" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI have a test request from a user" + } +] diff --git a/integration-tests/models/test_flash_llama_marlin.py b/integration-tests/models/test_flash_llama_marlin.py new file mode 100644 index 00000000..e7c5ccbd --- /dev/null +++ b/integration-tests/models/test_flash_llama_marlin.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_marlin_handle(launcher): + with launcher( + "neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_marlin(flash_llama_marlin_handle): + await flash_llama_marlin_handle.health(300) + return flash_llama_marlin_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin_load( + flash_llama_marlin, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_marlin, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index ce2af08f..839e8077 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -66,6 +66,8 @@ enum Quantization { /// triton kernel (wider support) when it's not. /// AWQ has faster kernels. Gptq, + /// 4 bit quantization. Requires a specific Marlin quantized model: . + Marlin, /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, /// but it is known that the model will be much slower to run than the native f16. #[deprecated( @@ -107,6 +109,9 @@ impl std::fmt::Display for Quantization { Quantization::Gptq => { write!(f, "gptq") } + Quantization::Marlin => { + write!(f, "marlin") + } Quantization::Awq => { write!(f, "awq") } diff --git a/server/Makefile b/server/Makefile index 5257b876..f2a45cc0 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,6 +3,7 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-eetq +include Makefile-marlin include Makefile-selective-scan unit-tests: diff --git a/server/Makefile-marlin b/server/Makefile-marlin new file mode 100644 index 00000000..8b4333e8 --- /dev/null +++ b/server/Makefile-marlin @@ -0,0 +1,13 @@ +marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c + +marlin: + # Clone marlin + pip install packaging + git clone https://github.com/IST-DASLab/marlin.git marlin + +build-marlin: marlin + cd marlin && git fetch && git checkout $(marlin_commit) + cd marlin && python setup.py build + +install-marlin: build-marlin + cd marlin && python setup.py install diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 16375ecd..68b429d0 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -21,6 +21,7 @@ class Quantization(str, Enum): eetq = "eetq" exl2 = "exl2" fp8 = "fp8" + marlin = "marlin" class Dtype(str, Enum): diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index ff99388e..3537b62d 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -222,6 +222,14 @@ def get_linear(weight, bias, quantize): raise NotImplementedError( "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" ) + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinLinear, MarlinWeight + + if not isinstance(weight, MarlinWeight): + raise NotImplementedError( + f"The passed weight is not `marlin` compatible, loader needs to be updated." + ) + linear = MarlinLinear(B=weight.B, s=weight.s, bias=bias) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py new file mode 100644 index 00000000..a860d84b --- /dev/null +++ b/server/text_generation_server/layers/marlin.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +try: + import marlin +except ImportError: + marlin = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + +MARLIN_TILE_SIZE = 16 + + +@dataclass +class MarlinWeight: + """ + Marlin weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. + s (torch.Tensor): float16 scales. + """ + + B: torch.Tensor + s: torch.Tensor + + +class MarlinLinear(nn.Module): + def __init__( + self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor] + ): + super().__init__() + + if not has_sm_8_0: + raise NotImplementedError( + "Using quantized marlin models requires CUDA capability 8.0 or later" + ) + + if marlin is None: + raise NotImplementedError( + "You do not seem to have marlin installed, either install it (cd server && make install-marlin)" + ) + + assert B.dtype == torch.int32 + assert s.dtype == torch.float16 + + in_features = B.shape[0] * MARLIN_TILE_SIZE + out_features = s.shape[1] + assert ( + in_features % 128 == 0 + ), f"Number of input features ({in_features}) not divisable by 128" + assert ( + out_features % 256 == 0 + ), f"Number of output features ({out_features}) not divisable by 256" + + group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0] + assert group_size in { + -1, + 128, + }, f"Group size must be -1 or 128, was {group_size}" + + self.register_buffer("B", B) + self.register_buffer("s", s) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 128 * 16, dtype=torch.int, device=B.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin is not None + C = torch.empty( + A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device + ) + marlin.mul( + A.view((-1, A.shape[-1])), + self.B, + C.view((-1, C.shape[-1])), + self.s, + self.workspace, + ) + + if self.bias is not None: + C += self.bias + + return C diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index afaaa1b8..192c2b42 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -64,7 +64,7 @@ class TensorParallelHead(SuperLayer): should_gather = False # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq"]: + if config.quantize in ["gptq", "awq", "eetq", "marlin"]: quantize = None # See above, exl2 LM head can be quantized or not. elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight): diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index dbe49039..ba353c11 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -260,7 +260,7 @@ def get_model( ) -> Model: global FLASH_ATTENTION if dtype is None: - if quantize in ["awq", "exl2", "gptq"]: + if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 else: diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9c32490e..63ce6543 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -271,6 +271,11 @@ def _load_gqa(config, prefix: str, weights): groupsize=groupsize, use_exllama=use_exllama, ) + elif config.quantize == "marlin": + # NOTE: at the time marlin support was added, the only model that + # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2), + # but it requires manual concatenation of weight files. + raise RuntimeError("dbrx models with marlin quantization are not yet supported") else: qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") q = qkv_slice[q_start:q_stop] diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 339198a7..04d05cd6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -145,7 +145,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if config.quantize not in ["gptq", "awq", "marlin"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 52a7c283..0178c911 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -46,6 +46,10 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads): prefix, weights, ) + elif config.quantize == "marlin": + raise RuntimeError( + "GPT-2 models with marlin quantization are not yet supported" + ) else: return _load_qkv(config, prefix, weights, head_size, num_heads) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 37cd6f3b..3900bf73 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -139,7 +139,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if config.quantize not in ["gptq", "awq", "marlin"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 0a47b1cc..6dda4b2b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -89,7 +89,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if config.quantize not in ["gptq", "awq", "marlin"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 2b035c2e..b1de58b2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -46,7 +46,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if config.quantize not in ["gptq", "awq", "marlin"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 74eedc51..4fa6516e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -29,6 +29,10 @@ def load_multi_mqa( return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) + elif config.quantize == "marlin": + raise RuntimeError( + "santacoder models with marlin quantization are not yet supported" + ) else: return _load_multi_mqa( config, prefix, weights, bias, head_size, num_heads, hidden_size diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index 0067a806..ef129e92 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -58,7 +58,7 @@ class FlashGPT2(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) prefix = "" diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 71d67d82..d02178d6 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -202,6 +202,12 @@ class Weights: groupsize=groupsize, use_exllama=False, ) + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeight + + B = self._get_qweight(f"{prefix}.B", blocks) + s = self._get_qweight(f"{prefix}.s", blocks) + weight = MarlinWeight(B=B, s=s) else: slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] @@ -316,9 +322,25 @@ class Weights: groupsize=groupsize, use_exllama=use_exllama, ) + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeight + + try: + B = torch.cat( + [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight, make sure the model is already quantized" + ) + s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1) + + weight = MarlinWeight(B=B, s=s) + else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) + return weight def get_tensor_shard(self, var, dim): @@ -481,6 +503,19 @@ class Weights: groupsize=groupsize, use_exllama=use_exllama, ) + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeight + + try: + B = self.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + s = self.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight From e6d8d2e50f69eb8c1bd203380a4353076c5d04ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 6 Jun 2024 11:51:52 +0000 Subject: [PATCH 036/340] marlin: support tp>1 when group_size==-1 --- server/text_generation_server/utils/weights.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index d02178d6..557656e7 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -513,7 +513,13 @@ class Weights: "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - s = self.get_sharded(f"{prefix}.s", dim=0) + num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when group_size == -1. share + # scales between all shards in this case. + s = self.get_tensor(f"{prefix}.s") + else: + s = self.get_sharded(f"{prefix}.s", dim=0) weight = MarlinWeight(B=B, s=s) else: From 7aaec2a54266f4ec66ccb39796b4afec883a8f23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 6 Jun 2024 11:25:56 +0000 Subject: [PATCH 037/340] marlin: improve build --- server/Makefile-marlin | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/server/Makefile-marlin b/server/Makefile-marlin index 8b4333e8..816546af 100644 --- a/server/Makefile-marlin +++ b/server/Makefile-marlin @@ -1,13 +1,11 @@ marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c -marlin: - # Clone marlin - pip install packaging - git clone https://github.com/IST-DASLab/marlin.git marlin - -build-marlin: marlin - cd marlin && git fetch && git checkout $(marlin_commit) - cd marlin && python setup.py build +build-marlin: + if [ ! -d 'marlin' ]; then \ + pip install -U ninja packaging --no-cache-dir && \ + git clone https://github.com/IST-DASLab/marlin.git marlin; \ + fi + cd marlin && git fetch && git checkout $(marlin_commit) && python setup.py build install-marlin: build-marlin - cd marlin && python setup.py install + cd marlin && git fetch && git checkout $(marlin_commit) && pip install -e . From 0494677284cb27dc727d86b142851846b86af7f7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 6 Jun 2024 18:51:42 +0200 Subject: [PATCH 038/340] Internal runner ? (#2023) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- Dockerfile | 3 --- Dockerfile_amd | 6 +++--- Dockerfile_intel | 6 +++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 531e02df..6492d723 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,9 +13,6 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder -ARG GIT_SHA -ARG DOCKER_LABEL - RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ diff --git a/Dockerfile_amd b/Dockerfile_amd index b0d181ea..c79bc03c 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -15,9 +15,6 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder -ARG GIT_SHA -ARG DOCKER_LABEL - RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -27,6 +24,9 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ COPY --from=planner /usr/src/recipe.json recipe.json RUN cargo chef cook --profile release-opt --recipe-path recipe.json +ARG GIT_SHA +ARG DOCKER_LABEL + COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto diff --git a/Dockerfile_intel b/Dockerfile_intel index 0a700003..ee963928 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -14,9 +14,6 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder -ARG GIT_SHA -ARG DOCKER_LABEL - RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -26,6 +23,9 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ COPY --from=planner /usr/src/recipe.json recipe.json RUN cargo chef cook --profile release-opt --recipe-path recipe.json +ARG GIT_SHA +ARG DOCKER_LABEL + COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto From 63cd798a194cd5189f44c414418a6dba458476b0 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 7 Jun 2024 01:12:57 +0800 Subject: [PATCH 039/340] Xpu gqa (#2013) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. Signed-off-by: Wang, Yi A --- Dockerfile_intel | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index ee963928..cb0e84bb 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -48,7 +48,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list -RUN apt-get update && apt install -y intel-basekit xpu-smi +RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build # Text Generation Inference base env ENV HUGGINGFACE_HUB_CACHE=/data \ @@ -57,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src -RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl -RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl +RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope # Install server COPY proto proto @@ -76,6 +76,10 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV CCL_ZE_IPC_EXCHANGE=sockets +ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest +ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include + +RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark From 93663b4567727234cb72f25b919fd23eee0cd0e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 31 May 2024 11:51:42 +0000 Subject: [PATCH 040/340] server: use chunked inputs The router will now send the input as chunks besides as a single string. This change modifies the server to process chunked input rather than strings. This also allows us to remove the image extraction code from the server. --- server/tests/models/test_bloom.py | 1 + server/tests/models/test_causal_lm.py | 1 + server/tests/models/test_santacoder.py | 8 +++ server/tests/models/test_seq2seq_lm.py | 1 + .../models/causal_lm.py | 4 +- .../models/flash_causal_lm.py | 9 ++- .../models/galactica.py | 5 +- .../models/idefics_causal_lm.py | 21 ++++--- server/text_generation_server/models/mamba.py | 3 +- .../models/pali_gemma.py | 39 +++++------- .../models/seq2seq_lm.py | 3 +- .../models/vlm_causal_lm.py | 60 ++++--------------- server/text_generation_server/utils/chunks.py | 27 +++++++++ 13 files changed, 94 insertions(+), 88 deletions(-) create mode 100644 server/text_generation_server/utils/chunks.py diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 66df708a..32ee6686 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 250fa354..6e6463bc 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index 1e40e766..cb2622d9 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="def", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, @@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="defworld", + input_chunks=generate_pb2.Input( + chunks=[ + generate_pb2.InputChunk( + text="defworld" + ) + ] + ), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 735ab5eb..943c3b08 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 81a02163..e896c831 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -86,7 +87,8 @@ class CausalLMBatch(Batch): max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + inputs.append(concat_text_chunks(r.input_chunks.chunks)) + next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d8c8838c..acf77b09 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -11,9 +11,10 @@ from loguru import logger from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase -from typing import Optional, Tuple, List, Type, Dict +from typing import Iterable, Optional, Tuple, List, Type, Dict from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens @@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch): ) @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer): + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer + ): batch_inputs = [] max_truncation = 0 for r in requests: - batch_inputs.append(r.inputs) + batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 4656fd45..d0f2b915 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -20,6 +20,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.chunks import concat_text_chunks # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py @@ -91,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic - inputs.append(escape_custom_split_sequence(r.inputs)) + inputs.append( + escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks)) + ) next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index e78a9655..f507d669 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -1,4 +1,5 @@ -import torch +from io import BytesIO +from PIL import Image import torch import time @@ -21,11 +22,6 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -from text_generation_server.models.vlm_causal_lm import split - -import re - -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") tracer = trace.get_tracer(__name__) @@ -109,7 +105,7 @@ class IdeficsCausalLMBatch(Batch): max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + inputs.append(r.input_chunks.chunks) next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) @@ -128,8 +124,15 @@ class IdeficsCausalLMBatch(Batch): for inp in inputs: # Each input is encoded into a list, where each element of this input list is either a string or a URL prompt = [] - for chunk in split(inp): - prompt.append(chunk["content"]) + for chunk in inp: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + prompt.append(chunk.text) + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + prompt.append(image) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") prompts.append(prompt) # The processor replaces the call to tokenizer, and diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index d9f90590..3133a137 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -27,6 +27,7 @@ from text_generation_server.models.types import ( Generation, GeneratedText, ) +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -139,7 +140,7 @@ class MambaBatch(Batch): max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + inputs.append(concat_text_chunks(r.input_chunks.chunks)) next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index d94b9526..e883ce02 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -1,55 +1,48 @@ +from io import BytesIO +from PIL import Image import torch import torch.distributed from opentelemetry import trace -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from text_generation_server.models.vlm_causal_lm import ( VlmCausalLM, VlmCausalLMBatch, image_text_replacement, - load_data_uri, - split, ) from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( PaliGemmaForConditionalGeneration, ) -from transformers import AutoProcessor, AutoConfig, AutoImageProcessor +from transformers import AutoProcessor, AutoConfig + +from text_generation_server.pb.generate_pb2 import Request tracer = trace.get_tracer(__name__) class PaliGemmaBatch(VlmCausalLMBatch): @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + def batch_tokenized_inputs( + cls, requests: Iterable[Request], tokenizer, processor, config + ): batch_inputs = [] image_inputs = [] max_truncation = 0 for r in requests: - chunks = split(r.inputs) full_text = "" image_id = 0 - for chunk in chunks: - if chunk["type"] == "text": - full_text += "" + chunk["content"] + "\n" - elif chunk["type"] == "image": - image = chunk["content"] - # Should never receive URLs anymore, processing should be done - # On the rust layer. - # This avoid making n queries per TP - # if image.startswith("https://") or image.startswith("http://"): - # image = processor.image_processor.fetch_images(image) - if image.startswith("data:"): - image = load_data_uri(image) - else: - raise RuntimeError( - "Cannot process input image not starting with data:" - ) + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += "" + chunk.text + "\n" + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) # TODO do_convert_RGB should be on by default ? image = image.convert("RGB") image_input = processor.image_processor(image, return_tensors="pt") full_text += image_text_replacement(image_input, config, image_id) image_inputs.append(image_input) else: - raise RuntimeError(f"Invalid chunk type {chunk['type']}") + raise RuntimeError(f"Invalid chunk type {chunk_type}") batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 6a0c812f..3bd09556 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -6,6 +6,7 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -93,7 +94,7 @@ class Seq2SeqLMBatch(Batch): padding_right_offset = 0 max_decode_tokens = 0 for i, r in enumerate(pb.requests): - inputs.append(r.inputs) + inputs.append(concat_text_chunks(r.input_chunks.chunks)) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) next_token_choosers.append( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 92d79070..59a6fab1 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,12 +1,9 @@ -import re import torch -import math from PIL import Image from io import BytesIO -import base64 from opentelemetry import trace -from typing import Optional, Tuple, List, Type, Dict +from typing import Iterable, Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution @@ -18,25 +15,6 @@ from text_generation_server.models.flash_mistral import ( tracer = trace.get_tracer(__name__) -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") - - -def split(string) -> List[Dict[str, str]]: - parts = [] - cursor = 0 - for pattern in IMAGES.finditer(string): - start = pattern.start() - if start != cursor: - parts.append({"type": "text", "content": string[cursor:start]}) - - parts.append({"type": "image", "content": pattern.group(1)}) - cursor = pattern.end() - - if cursor != len(string): - parts.append({"type": "text", "content": string[cursor:]}) - - return parts - def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features -def load_data_uri(image_uri: str) -> Image.Image: - image_uri = image_uri.split(",")[-1] - content = base64.b64decode(image_uri) - image = Image.open(BytesIO(content)) - return image - - class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] @@ -159,35 +130,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch): return batch @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config + ): batch_inputs = [] image_inputs = [] max_truncation = 0 for r in requests: - chunks = split(r.inputs) full_text = "" image_id = 0 - for chunk in chunks: - if chunk["type"] == "text": - full_text += chunk["content"] - elif chunk["type"] == "image": - image = chunk["content"] - # Should never receive URLs anymore, processing should be done - # On the rust layer. - # This avoid making n queries per TP - # if image.startswith("https://") or image.startswith("http://"): - # image = processor.image_processor.fetch_images(image) - if image.startswith("data:"): - image = load_data_uri(image) - else: - raise RuntimeError( - "Cannot process input image not starting with data:" - ) + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += chunk.text + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) image_input = processor.image_processor(image, return_tensors="pt") full_text += image_text_replacement(image_input, config, image_id) image_inputs.append(image_input) else: - raise RuntimeError(f"Invalid chunk type {chunk['type']}") + raise RuntimeError(f"Invalid chunk type {chunk_type}") batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) diff --git a/server/text_generation_server/utils/chunks.py b/server/text_generation_server/utils/chunks.py new file mode 100644 index 00000000..73962ea3 --- /dev/null +++ b/server/text_generation_server/utils/chunks.py @@ -0,0 +1,27 @@ +from typing import Iterable + +from loguru import logger + +from text_generation_server.pb import generate_pb2 + + +def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str: + """ + Concatenate text in text chunks. Non-text chunks are dropped. + """ + text = None + for chunk in chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + if text is None: + text = chunk.text + else: + raise NotImplementedError("Request contained more than one text chunk") + else: + # We cannot reject this, e.g. warmup sends an image chunk. + logger.debug(f"Encountered non-text chunk type {chunk_type}") + + if text is None: + raise NotImplementedError("Request without a text chunk") + + return text From 5e035063cf97f1037e209a4a917f41d4b508eee0 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 10 Jun 2024 09:09:50 +0200 Subject: [PATCH 041/340] ROCm and sliding windows fixes (#2033) * update vllm commit & fix models using sliding window * update * update commit * fix bug where tunableop is bound to cuda graph even when cuda graph are disabled * enable tunableop by default * fix sliding window * address review * dead code * precise comment * is it flaky? --- launcher/src/main.rs | 8 +++++++ server/Makefile-vllm | 4 ++-- server/text_generation_server/cli.py | 2 ++ .../layers/attention/rocm.py | 11 +++------- .../layers/attention/xpu.py | 5 +---- .../text_generation_server/models/__init__.py | 21 +++++++++++-------- .../models/flash_causal_lm.py | 7 ++++++- server/text_generation_server/server.py | 2 ++ 8 files changed, 36 insertions(+), 24 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 839e8077..21ca63d4 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -483,6 +483,7 @@ fn shard_manager( rope_factor: Option, max_total_tokens: usize, max_batch_size: Option, + max_input_tokens: usize, otlp_endpoint: Option, log_level: LevelFilter, status_sender: mpsc::Sender, @@ -554,6 +555,10 @@ fn shard_manager( shard_args.push(otlp_endpoint); } + // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter. + shard_args.push("--max-input-tokens".to_string()); + shard_args.push(max_input_tokens.to_string()); + // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); @@ -1013,6 +1018,7 @@ fn spawn_shards( args: &Args, cuda_graphs: Vec, max_total_tokens: usize, + max_input_tokens: usize, max_log_level: LevelFilter, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, @@ -1070,6 +1076,7 @@ fn spawn_shards( rope_factor, max_total_tokens, max_batch_size, + max_input_tokens, otlp_endpoint, max_log_level, status_sender, @@ -1554,6 +1561,7 @@ fn main() -> Result<(), LauncherError> { &args, cuda_graphs, max_total_tokens, + max_input_tokens, max_log_level, shutdown.clone(), &shutdown_receiver, diff --git a/server/Makefile-vllm b/server/Makefile-vllm index ded2f5d2..8c0437ea 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,5 +1,5 @@ commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa -commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 +commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ @@ -19,5 +19,5 @@ build-vllm-rocm: PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build install-vllm-rocm: build-vllm-rocm - cd vllm && git fetch && git checkout $(commit_rocm) && \ + cd vllm && git fetch && git checkout $(commit_rocm) && \ PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e . diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68b429d0..430323bc 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -42,6 +42,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + max_input_tokens: Optional[int] = None, ): if sharded: assert ( @@ -98,6 +99,7 @@ def serve( dtype, trust_remote_code, uds_path, + max_input_tokens, ) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 535810aa..91ed5818 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -169,10 +169,8 @@ if ENGINE == "ck": ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") - if window_size_left != -1: - raise ValueError( - f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) + + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( q, k, @@ -204,10 +202,7 @@ elif ENGINE == "triton": window_size_left=-1, causal=True, ): - if window_size_left != -1: - raise ValueError( - f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( q, k, diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py index d9a096f9..8b6cb87b 100644 --- a/server/text_generation_server/layers/attention/xpu.py +++ b/server/text_generation_server/layers/attention/xpu.py @@ -14,10 +14,7 @@ def attention( softmax_scale, window_size_left=-1, ): - if window_size_left != -1: - raise ValueError( - f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( q, k, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ba353c11..a61cb83b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.phi import Phi +from text_generation_server.utils.import_utils import SYSTEM + # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. torch.backends.cuda.matmul.allow_tf32 = True @@ -257,6 +259,7 @@ def get_model( speculate: Optional[int], dtype: Optional[str], trust_remote_code: bool, + max_input_tokens: int, ) -> Model: global FLASH_ATTENTION if dtype is None: @@ -410,11 +413,15 @@ def get_model( "Sharding is currently not supported with `exl2` quantization" ) sliding_window = config_dict.get("sliding_window", -1) - if sliding_window != -1 and not SUPPORTS_WINDOWING: - logger.warning( - f"Flash attention is available, but doesn't support windowing which is required by model {model_id}" + + if ( + (sliding_window is not None and sliding_window != -1) + and not SUPPORTS_WINDOWING + and max_input_tokens > sliding_window + ): + raise ValueError( + f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." ) - FLASH_ATTENTION = False if model_type == MAMBA: return Mamba( @@ -701,7 +708,6 @@ def get_model( ) if model_type == MISTRAL: - sliding_window = config_dict.get("sliding_window", -1) if FLASH_ATTENTION: return FlashMistral( model_id, @@ -724,7 +730,6 @@ def get_model( ) if model_type == MIXTRAL: - sliding_window = config_dict.get("sliding_window", -1) if FLASH_ATTENTION: return FlashMixtral( model_id, @@ -747,7 +752,6 @@ def get_model( ) if model_type == STARCODER2: - sliding_window = config_dict.get("sliding_window", -1) if FLASH_ATTENTION: return FlashStarcoder2( model_id, @@ -771,8 +775,7 @@ def get_model( ) if model_type == QWEN2: - sliding_window = config_dict.get("sliding_window", -1) - if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING: + if FLASH_ATTENTION: return FlashQwen2( model_id, revision, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index acf77b09..d16d3710 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -902,6 +902,8 @@ class FlashCausalLM(Model): os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" ): + torch.cuda.tunable.enable() + if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": torch.cuda.tunable.tuning_enable(True) @@ -910,8 +912,11 @@ class FlashCausalLM(Model): int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",") ] - else: + elif CUDA_GRAPHS is not None: tuning_sequences = CUDA_GRAPHS + else: + # For seqlen = 1, we dispatch to LLMM1 kernel. + tuning_sequences = [2, 3, 4, 5, 6, 7] tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 4118b3f6..569b6925 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -199,6 +199,7 @@ def serve( dtype: Optional[str], trust_remote_code: bool, uds_path: Path, + max_input_tokens: int, ): async def serve_inner( model_id: str, @@ -229,6 +230,7 @@ def serve( speculate, dtype, trust_remote_code, + max_input_tokens, ) except Exception: logger.exception("Error when initializing model") From 748764efb4595b77210abd89ec16ff40c0b09dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 10 Jun 2024 09:22:29 +0200 Subject: [PATCH 042/340] Add Phi-3 medium support (#2039) Add support for Phi-3-medium The main difference between the medium and mini models is that medium uses grouped query attention with a packed QKV matrix. This change adds support for GQA with packed matrixes to `Weights.get_weights_col_packed` and uses it for Phi-3. This also allows us to remove the custom implementation of GQA from dbrx attention loading. --- .../layers/tensor_parallel.py | 17 ++- .../custom_modeling/flash_dbrx_modeling.py | 131 +----------------- .../custom_modeling/flash_gpt2_modeling.py | 7 +- .../custom_modeling/flash_llama_modeling.py | 9 ++ .../text_generation_server/utils/weights.py | 118 +++++++++++----- 5 files changed, 118 insertions(+), 164 deletions(-) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 192c2b42..6005f737 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer): return cls(linear) @classmethod - def load_qkv(cls, config, prefix: str, weights, bias: bool): + def load_qkv( + cls, + config, + prefix: str, + weights, + bias: bool, + num_heads: int, + num_key_value_heads: int, + ): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) + weight = weights.get_weights_col_packed_qkv( + prefix, + quantize=config.quantize, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + ) if bias: raise NotImplementedError("packed_qkv only implemented for baichuan") else: diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 63ce6543..94cf6452 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,7 +20,6 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from loguru import logger from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": @@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor: def load_attention(config, prefix, weights): - if config.n_heads != config.attn_config.kv_n_heads: - return _load_gqa(config, prefix, weights) - else: - return TensorParallelColumnLinear.load_qkv( - config, - prefix=f"{prefix}.Wqkv", - weights=weights, - bias=False, - ) - - -def _load_gqa(config, prefix: str, weights): - assert config.d_model % config.n_heads == 0 - assert config.n_heads % weights.process_group.size() == 0 - - head_dim = config.d_model // config.n_heads - world_size = weights.process_group.size() - rank = weights.process_group.rank() - - q_block_size = config.d_model // world_size - q_start = rank * q_block_size - q_stop = (rank + 1) * q_block_size - - kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size - k_offset = config.d_model - k_start = k_offset + rank * kv_block_size - k_stop = k_offset + (rank + 1) * kv_block_size - - v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim - v_start = v_offset + rank * kv_block_size - v_stop = v_offset + (rank + 1) * kv_block_size - - if config.quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - - try: - qweight_slice = weights._get_slice(f"{prefix}.qweight") - q_qweight = qweight_slice[:, q_start:q_stop] - k_qweight = qweight_slice[:, k_start:k_stop] - v_qweight = qweight_slice[:, v_start:v_stop] - - qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{config.quantize}` weight, make sure the model is already quantized" - ) - - qzeros_slice = weights._get_slice(f"{prefix}.qzeros") - q_qzeros = qzeros_slice[:, q_start:q_stop] - k_qzeros = qzeros_slice[:, k_start:k_stop] - v_qzeros = qzeros_slice[:, v_start:v_stop] - - qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1) - - scales_slice = weights._get_slice(f"{prefix}.scales") - q_scales = scales_slice[:, q_start:q_stop] - k_scales = scales_slice[:, k_start:k_stop] - v_scales = scales_slice[:, v_start:v_stop] - - scales = torch.cat([q_scales, k_scales, v_scales], dim=1) - - bits, groupsize, desc_act, quant_method = weights._get_gptq_params() - - from text_generation_server.layers import HAS_EXLLAMA - - use_exllama = ( - bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act - ) - - if config.quantize == "gptq" and quant_method == "gptq": - g_idx_slice = weights._get_slice(f"{prefix}.g_idx") - q_g_idx = g_idx_slice[:, q_start:q_stop] - k_g_idx = g_idx_slice[:, k_start:k_stop] - v_g_idx = g_idx_slice[:, v_start:v_stop] - - w = [q_g_idx, k_g_idx, v_g_idx] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif config.quantize == "gptq" and quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conveersion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) - // groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=bits, - groupsize=groupsize, - use_exllama=use_exllama, - ) - elif config.quantize == "marlin": - # NOTE: at the time marlin support was added, the only model that - # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2), - # but it requires manual concatenation of weight files. - raise RuntimeError("dbrx models with marlin quantization are not yet supported") - else: - qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") - q = qkv_slice[q_start:q_stop] - k = qkv_slice[k_start:k_stop] - v = qkv_slice[v_start:v_stop] - - weight = torch.cat([q, k, v], dim=0) - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.Wqkv", + weights=weights, + bias=False, + num_heads=config.n_heads, + num_key_value_heads=config.attn_config.kv_n_heads, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 0178c911..0c01f56a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights): rank = weights.process_group.rank() # Weights - weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize) + weight = weights.get_weights_col_packed_qkv( + f"{prefix}.c_attn", + config.quantize, + config.num_attention_heads, + config.num_attention_heads, + ) # Bias slice_ = weights._get_slice(f"{prefix}.c_attn.bias") diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index cef712f0..0d06d104 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -62,6 +62,8 @@ def load_attention(config, prefix, weights): prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": return TensorParallelColumnLinear.load_qkv( @@ -69,6 +71,8 @@ def load_attention(config, prefix, weights): prefix=f"{prefix}.W_pack", weights=weights, bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, ) # otherwise, load the default attention based on the number of heads @@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module): f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " f"and `num_shards`: {weights.process_group.size()}" ) + if config.num_key_value_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 557656e7..4d5fcb25 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass, field import os from pathlib import Path -from typing import List, Dict, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger @@ -121,49 +120,62 @@ class Weights: ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str, blocks: int): + def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]): slice_ = self._get_slice(name) total_size = slice_.get_shape()[1] - assert ( - total_size % blocks == 0 - ), f"Prepacked quantized matrix is not divisible by {blocks}" - single_size = total_size // blocks + block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes) + world_size = self.process_group.size() rank = self.process_group.rank() - assert ( - single_size % world_size == 0 - ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards" - block_size = single_size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - weights = [] - for block in range(blocks): - weights.append( - slice_[:, start + block * single_size : stop + block * single_size] - ) + block_offset = 0 + for block_size in block_sizes: + assert ( + block_size % world_size == 0 + ), f"Prepacked qkv cannot be sharded across {world_size} shards" + shard_block_size = block_size // world_size + start = rank * shard_block_size + stop = (rank + 1) * shard_block_size + weights.append(slice_[:, block_offset + start : block_offset + stop]) + block_offset += block_size weight = torch.cat(weights, dim=1) weight = weight.to(device=self.device) return weight - def get_weights_col_packed_qkv(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 3) + def get_weights_col_packed_qkv( + self, + prefix: str, + quantize: str, + num_heads: int, + num_key_value_heads: int, + ): + return self.get_weights_col_packed( + prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + ) def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): return self.get_weights_col_packed(prefix, quantize, 2) - def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int): + def get_weights_col_packed( + self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] + ): """ Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor + already alternating Q,K,V within the main tensor. + + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. """ if quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight try: - qweight = self._get_qweight(f"{prefix}.qweight", blocks) + qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." @@ -171,8 +183,8 @@ class Weights: bits, groupsize, _, quant_method = self._get_gptq_params() - qzeros = self._get_qweight(f"{prefix}.qzeros", blocks) - scales = self._get_qweight(f"{prefix}.scales", blocks) + qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes) + scales = self._get_qweight(f"{prefix}.scales", block_sizes) scales = scales.to(dtype=self.dtype) if quantize == "gptq" and quant_method == "gptq": @@ -205,27 +217,31 @@ class Weights: elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeight - B = self._get_qweight(f"{prefix}.B", blocks) - s = self._get_qweight(f"{prefix}.s", blocks) + B = self._get_qweight(f"{prefix}.B", block_sizes) + s = self._get_qweight(f"{prefix}.s", block_sizes) weight = MarlinWeight(B=B, s=s) else: slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] - assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" - single_size = total_size // blocks + block_sizes = _blocks_to_block_sizes( + total_size=total_size, blocks=block_sizes + ) + world_size = self.process_group.size() rank = self.process_group.rank() - assert ( - single_size % world_size == 0 - ), f"Prepacked qkv cannot be sharded across {world_size} shards" - block_size = single_size // world_size - start = rank * block_size - stop = (rank + 1) * block_size tensors = [] - for i in range(blocks): - tensor = slice_[start + i * single_size : stop + i * single_size] + block_offset = 0 + for block_size in block_sizes: + assert ( + block_size % world_size == 0 + ), f"Prepacked weights cannot be sharded across {world_size} shards" + shard_block_size = block_size // world_size + start = rank * shard_block_size + stop = (rank + 1) * shard_block_size + tensor = slice_[block_offset + start : block_offset + stop] tensors.append(tensor) + block_offset += block_size weight = torch.cat(tensors, dim=0) weight = weight.to(device=self.device) weight = weight.to(dtype=self.dtype) @@ -593,3 +609,31 @@ class Weights: self.quant_method = "awq" except Exception: pass + + +def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: + """ + Convert block count or proportions to block sizes. + + This function accepts + + - The number of blocks (int), in which case the block size is + total_size//blocks; or + - A list of block sizes (List[int]). + + In the latter case, if sum(blocks) < total_size, the ratios between + the block sizes will be preserved. For instance, if blocks is + [2, 1, 1] and total_size is 1024, the returned block sizes are + [512, 256, 256]. + """ + if isinstance(blocks, list): + total_blocks = sum(blocks) + assert ( + total_size % total_blocks == 0 + ), f"Cannot split {total_size} in proportional blocks: {blocks}" + part_size = total_size // total_blocks + return [part_size * block for block in blocks] + else: + assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" + single_size = total_size // blocks + return [single_size] * blocks From ac733178949f2648e0a8ba3296081d9286839e1a Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Mon, 10 Jun 2024 17:54:13 +0200 Subject: [PATCH 043/340] feat(ci): add trufflehog secrets detection (#2038) --- .github/workflows/trufflehog.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .github/workflows/trufflehog.yml diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 00000000..b8a3316e --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,22 @@ +on: + push: + +name: Secret Leaks + +permissions: + contents: read + id-token: write + issues: write + pull-requests: write + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main + From 5381fa7393fdc0560f99d7e1bce81c4d6bd9f841 Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Mon, 10 Jun 2024 18:16:53 +0200 Subject: [PATCH 044/340] fix(ci): remove unnecessary permissions (#2045) --- .github/workflows/trufflehog.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index b8a3316e..8bc60eff 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -5,9 +5,6 @@ name: Secret Leaks permissions: contents: read - id-token: write - issues: write - pull-requests: write jobs: trufflehog: From eb8b76d1d28f60dae0d6c305029ed3dfe2e965e2 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 11 Jun 2024 13:30:29 +0200 Subject: [PATCH 045/340] Update LLMM1 bound (#2050) update commit --- server/Makefile-vllm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 8c0437ea..2f2b5ef6 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,5 +1,5 @@ commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa -commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0 +commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ From 99c947452df58ac7a2bd79abc3699e72cc9816b6 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Jun 2024 10:44:56 -0400 Subject: [PATCH 046/340] Support chat response format (#2046) * feat: support response_format in chat * fix: adjust typos * fix: add trufflehog lint --- .github/workflows/trufflehog.yml | 1 - ...st_grammar_response_format_llama_json.json | 23 ++++ .../test_grammar_response_format_llama.py | 101 ++++++++++++++++++ router/src/lib.rs | 8 ++ router/src/server.rs | 30 ++++-- 5 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json create mode 100644 integration-tests/models/test_grammar_response_format_llama.py diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index 8bc60eff..b406d43b 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -16,4 +16,3 @@ jobs: fetch-depth: 0 - name: Secret Scanning uses: trufflesecurity/trufflehog@main - diff --git a/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json new file mode 100644 index 00000000..83390832 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}", + "role": "assistant" + } + } + ], + "created": 1718044128, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.5-dev0-native", + "usage": { + "completion_tokens": 39, + "prompt_tokens": 136, + "total_tokens": 175 + } +} diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py new file mode 100644 index 00000000..9c4c048e --- /dev/null +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -0,0 +1,101 @@ +import pytest +import requests +from pydantic import BaseModel +from typing import List + + +@pytest.fixture(scope="module") +def llama_grammar_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + num_shard=1, + disable_grammar_support=False, + use_flash_attention=False, + max_batch_prefill_tokens=3000, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def llama_grammar(llama_grammar_handle): + await llama_grammar_handle.health(300) + return llama_grammar_handle.client + + +@pytest.mark.asyncio +async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot): + + class Weather(BaseModel): + unit: str + temperature: List[int] + + # send the request + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + "seed": 42, + "max_tokens": 500, + "response_format": {"type": "json_object", "value": Weather.schema()}, + }, + ) + + chat_completion = response.json() + called = chat_completion["choices"][0]["message"]["content"] + + assert response.status_code == 200 + assert ( + called + == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}' + ) + assert chat_completion == response_snapshot + + +@pytest.mark.asyncio +async def test_grammar_response_format_llama_error_if_tools_not_installed( + llama_grammar, +): + class Weather(BaseModel): + unit: str + temperature: List[int] + + # send the request + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + "seed": 42, + "max_tokens": 500, + "tools": [], + "response_format": {"type": "json_object", "value": Weather.schema()}, + }, + ) + + # 422 means the server was unable to process the request because it contains invalid data. + assert response.status_code == 422 + assert response.json() == { + "error": "Grammar and tools are mutually exclusive", + "error_type": "grammar and tools", + } diff --git a/router/src/lib.rs b/router/src/lib.rs index b6902c49..1016019d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -89,6 +89,7 @@ pub(crate) enum GrammarType { /// JSON Schema is a declarative language that allows to annotate JSON documents /// with types and descriptions. #[serde(rename = "json")] + #[serde(alias = "json_object")] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] Json(serde_json::Value), #[serde(rename = "regex")] @@ -791,6 +792,13 @@ pub(crate) struct ChatRequest { #[schema(nullable = true, example = "null")] #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] pub tool_choice: Option, + + /// Response format constraints for the generation. + /// + /// NOTE: A request can use `response_format` OR `tools` but not both. + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub response_format: Option, } fn default_tool_prompt() -> Option { diff --git a/router/src/server.rs b/router/src/server.rs index 30479b0e..e3c2c4f9 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1016,6 +1016,7 @@ async fn chat_completions( tool_choice, tool_prompt, temperature, + response_format, .. } = req; @@ -1030,6 +1031,18 @@ async fn chat_completions( other => (true, other), }; + // response_format and tools are mutually exclusive + if response_format.is_some() && tools.as_ref().is_some() { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Grammar and tools are mutually exclusive".to_string(), + error_type: "grammar and tools".to_string(), + }), + )); + } + // extract tool grammar if present let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { Ok(grammar) => grammar, @@ -1046,16 +1059,21 @@ async fn chat_completions( } }; - let grammar_with_prompt = tool_grammar + // determine the appropriate arguments for apply_chat_template + let tools_grammar_prompt = tool_grammar .as_ref() .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); - let typed_grammar = grammar_with_prompt - .as_ref() - .map(|(grammar, _)| grammar.clone()); + let (tools_grammar_prompt, grammar) = match response_format { + Some(response_format) => (None, Some(response_format)), + None => ( + tools_grammar_prompt.clone(), + tools_grammar_prompt.map(|(grammar, _)| grammar.clone()), + ), + }; // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) { + let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -1091,7 +1109,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: req.top_logprobs, - grammar: typed_grammar, + grammar, }, }; From e85e7ac4f950b84d58c462ca20aa02a5ef98470e Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 12 Jun 2024 18:22:20 +0200 Subject: [PATCH 047/340] fix(server): fix OPT implementation (#2061) --- .../models/custom_modeling/opt_modeling.py | 2 +- server/text_generation_server/models/gpt_neox.py | 3 +-- server/text_generation_server/models/opt.py | 4 ++-- server/text_generation_server/models/rw.py | 8 +++++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 83d62dea..9b2d01e0 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel): return_dict=return_dict, ) - logits, speculative_logits = self.lm_head(outputs) + logits, speculative_logits = self.lm_head(outputs.last_hidden_state) loss = None diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index c0e1adf2..d1f8f5be 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -85,5 +85,4 @@ class GPTNeoxSharded(CausalLM): use_cache=True, ) - logits = outputs.logits - return logits, speculative_logits, outputs.past_key_values + return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 5b84f4ff..87319ef0 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -75,11 +75,11 @@ class OPTSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True, ) - return outputs.logits, outputs.past_key_values + return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index d4764ded..50f6ead8 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -71,11 +71,13 @@ class RW(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + ): # Model Forward - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, + use_cache=True, ) - return outputs.logits, outputs.past_key_values + + return outputs.logits, speculative_logits, outputs.past_key_values From 2fdad64eceb8a24c261ff8fcccd37d2538457491 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 12 Jun 2024 18:24:47 +0200 Subject: [PATCH 048/340] fix(layers): fix SuRotaryEmbedding (#2060) * fix(layers): fix SuRotaryEmbedding * change arange * remove logs --- .../text_generation_server/layers/rotary.py | 26 ++++++++++--------- .../models/flash_phi.py | 3 +-- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 648d28ab..c2f12189 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen - if seqlen > self.original_max_position_embeddings: - inv_freq = self.long_inv_freq - else: - inv_freq = self.short_inv_freq - t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) - if self.scaling_factor is not None: - t /= self.scaling_factor - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + freqs = torch.cat([short_freqs, long_freqs]) + + self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype) class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 32b573a9..6a52c1d7 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -8,7 +8,6 @@ from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_phi_modeling import ( FlashPhiForCausalLM, - PhiConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM): trust_remote_code=trust_remote_code, ) - config = PhiConfig.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize From d0a1d50fd3ded419a7795d006d58b11362d70057 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 13 Jun 2024 11:53:49 -0400 Subject: [PATCH 049/340] PR #2049 CI run (#2054) * Use minijinja's pycompat mode for python methods * fix: cargo fmt lint for pre commit --------- Co-authored-by: Armin Ronacher --- Cargo.lock | 16 ++++++++++++++-- router/Cargo.toml | 3 ++- router/src/infer/mod.rs | 13 +++++-------- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8204ae83..4d537290 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1972,12 +1972,23 @@ dependencies = [ [[package]] name = "minijinja" -version = "1.0.12" -source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4" dependencies = [ "serde", ] +[[package]] +name = "minijinja-contrib" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07" +dependencies = [ + "minijinja", + "serde", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3731,6 +3742,7 @@ dependencies = [ "metrics", "metrics-exporter-prometheus", "minijinja", + "minijinja-contrib", "ngrok", "nohash-hasher", "once_cell", diff --git a/router/Cargo.toml b/router/Cargo.toml index 2e6264be..3262e7e6 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -44,7 +44,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } -minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" } +minijinja = { version = "2.0.2" } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 20630c1b..07c334a3 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -12,6 +12,8 @@ use crate::{ use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; +use minijinja_contrib::pycompat; + use serde_json::{json, Map, Value}; use std::collections::HashMap; use std::sync::Arc; @@ -62,14 +64,7 @@ impl Infer { .find(|t| t.name == "default") .map(|t| t.template), }) - .map(|t| { - // .strip() is not supported in minijinja - // .capitalize() is not supported in minijinja but we can use | capitalize - let t = t - .replace(".strip()", " | trim") - .replace(".capitalize()", " | capitalize"); - ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) - }); + .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); @@ -277,6 +272,8 @@ struct ChatTemplate { impl ChatTemplate { fn new(template: String, bos_token: Option, eos_token: Option) -> Self { let mut env = Box::new(Environment::new()); + // enable things like .strip() or .capitalize() + env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); From 7ce29b1ef2fa8286736d10d1ea36774a8cbb1419 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 13 Jun 2024 12:51:51 -0400 Subject: [PATCH 050/340] implement Open Inference Protocol endpoints (#1942) * feat: add kserve feature and basic routes * feat: implement infer endpoint wrapper around generate * fix: refactor and improve types * fix: improve infer and simplify * fix: cleanup and improve api docs * fix: refactor and encapsulate kserve feat in file * fix: remove typos after rebase --- router/Cargo.toml | 1 + router/src/kserve.rs | 247 +++++++++++++++++++++++++++++++++++++++++++ router/src/lib.rs | 3 + router/src/server.rs | 98 +++++++++++++---- 4 files changed, 328 insertions(+), 21 deletions(-) create mode 100644 router/src/kserve.rs diff --git a/router/Cargo.toml b/router/Cargo.toml index 3262e7e6..5bf4c00c 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -59,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } default = ["ngrok"] ngrok = ["dep:ngrok"] google = [] +kserve = [] diff --git a/router/src/kserve.rs b/router/src/kserve.rs new file mode 100644 index 00000000..b64efd38 --- /dev/null +++ b/router/src/kserve.rs @@ -0,0 +1,247 @@ +use crate::{ + default_parameters, + server::{generate_internal, ComputeType}, + Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema, +}; +use axum::extract::{Extension, Path}; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use futures::stream::FuturesUnordered; +use futures::TryStreamExt; +use reqwest::header::HeaderMap; +use reqwest::StatusCode; + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct OutputChunk { + pub name: String, + pub shape: Vec, + pub datatype: String, + pub data: Vec, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct InferenceOutput { + pub id: String, + pub outputs: Vec, +} + +#[derive(Debug, Deserialize, ToSchema)] +pub(crate) struct InferenceRequest { + pub id: String, + #[serde(default = "default_parameters")] + pub parameters: GenerateParameters, + pub inputs: Vec, + pub outputs: Vec, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub(crate) struct Input { + pub name: String, + pub shape: Vec, + pub datatype: String, + pub data: Vec, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub(crate) struct Output { + pub name: String, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct LiveResponse { + pub live: bool, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct ReadyResponse { + pub live: bool, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct MetadataServerResponse { + pub name: String, + pub version: String, + pub extensions: Vec, +} + +// Routes + +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/v2/health/live", + responses( + (status = 200, description = "Service is live", body = LiveReponse), + (status = 404, description = "Service not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kserve_health_live() -> Result)> { + let data = LiveResponse { live: true }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/v2/health/ready", + responses( + (status = 200, description = "Service is ready", body = ReadyResponse), + (status = 404, description = "Service not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kserve_health_ready() -> Result)> { + let data = ReadyResponse { live: true }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/v2", + responses( + (status = 200, description = "Metadata retrieved", body = MetadataServerResponse), + (status = 404, description = "Service not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kerve_server_metadata() -> Result)> { + let data = MetadataServerResponse { + name: "text-generation-inference".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + extensions: vec![ + "health".to_string(), + "models".to_string(), + "metrics".to_string(), + ], + }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/v2/models/{model_name}/versions/{model_version}", + responses( + (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse), + (status = 404, description = "Model or version not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kserve_model_metadata( + Path((model_name, model_version)): Path<(String, String)>, +) -> Result)> { + let data = MetadataServerResponse { + name: model_name, + version: model_version, + extensions: vec!["infer".to_string(), "ready".to_string()], + }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/v2/models/{model_name}/versions/{model_version}/infer", + request_body = Json, + responses( + (status = 200, description = "Inference executed successfully", body = InferenceOutput), + (status = 404, description = "Model or version not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kserve_model_infer( + infer: Extension, + Extension(compute_type): Extension, + Json(payload): Json, +) -> Result)> { + let id = payload.id.clone(); + let str_inputs = payload + .inputs + .iter() + .map(|input| { + std::str::from_utf8(&input.data).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "utf8".to_string(), + }), + ) + }) + }) + .collect::, _>>()?; + + if str_inputs.len() != payload.outputs.len() { + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Inputs and outputs length mismatch".to_string(), + error_type: "length mismatch".to_string(), + }), + )); + } + + let output_chunks = str_inputs + .iter() + .zip(&payload.outputs) + .map(|(str_input, output)| { + let generate_request = GenerateRequest { + inputs: str_input.to_string(), + parameters: payload.parameters.clone(), + }; + let infer = infer.clone(); + let compute_type = compute_type.clone(); + let span = tracing::Span::current(); + async move { + generate_internal(infer, compute_type, Json(generate_request), span) + .await + .map(|(_, Json(generation))| { + let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); + OutputChunk { + name: output.name.clone(), + shape: vec![1, generation_as_bytes.len()], + datatype: "BYTES".to_string(), + data: generation_as_bytes, + } + }) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) + } + }) + .collect::>() + .try_collect::>() + .await?; + + let inference_output = InferenceOutput { + id: id.clone(), + outputs: output_chunks, + }; + + Ok((HeaderMap::new(), Json(inference_output)).into_response()) +} + +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/v2/models/{model_name}/versions/{model_version}/ready", + responses( + (status = 200, description = "Model version is ready", body = ReadyResponse), + (status = 404, description = "Model or version not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kserve_model_metadata_ready( + Path((_model_name, _model_version)): Path<(String, String)>, +) -> Result)> { + let data = ReadyResponse { live: true }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 1016019d..b0b93c13 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -4,6 +4,9 @@ mod infer; pub mod server; mod validation; +#[cfg(feature = "kserve")] +mod kserve; + use serde::{Deserialize, Serialize}; use tracing::warn; use utoipa::ToSchema; diff --git a/router/src/server.rs b/router/src/server.rs index e3c2c4f9..aa872df9 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -4,6 +4,11 @@ use crate::infer::v2::SchedulerV2; use crate::infer::v3::SchedulerV3; use crate::infer::{HealthCheck, Scheduler}; use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; +#[cfg(feature = "kserve")] +use crate::kserve::{ + kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, + kserve_model_metadata, kserve_model_metadata_ready, +}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, @@ -172,7 +177,7 @@ async fn generate( generate_internal(infer, ComputeType(compute_type), Json(req), span).await } -async fn generate_internal( +pub(crate) async fn generate_internal( infer: Extension, ComputeType(compute_type): ComputeType, Json(req): Json, @@ -1727,28 +1732,58 @@ pub async fn run( docker_label: option_env!("DOCKER_LABEL"), }; - // Define VertextApiDoc conditionally only if the "google" feature is enabled - let doc = { - // avoid `mut` if possible - #[cfg(feature = "google")] - { - use crate::VertexInstance; + #[allow(unused_mut)] // mut is needed for conditional compilation + let mut doc = ApiDoc::openapi(); - #[derive(OpenApi)] - #[openapi( - paths(vertex_compatibility), - components(schemas(VertexInstance, VertexRequest, VertexResponse)) - )] - struct VertextApiDoc; + #[cfg(feature = "google")] + { + use crate::VertexInstance; - // limiting mutability to the smallest scope necessary - let mut doc = ApiDoc::openapi(); - doc.merge(VertextApiDoc::openapi()); - doc - } - #[cfg(not(feature = "google"))] - ApiDoc::openapi() - }; + #[derive(OpenApi)] + #[openapi( + paths(vertex_compatibility), + components(schemas(VertexInstance, VertexRequest, VertexResponse)) + )] + struct VertexApiDoc; + + doc.merge(VertexApiDoc::openapi()); + } + + #[cfg(feature = "kserve")] + { + use crate::kserve::{ + InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk, + ReadyResponse, + }; + use crate::kserve::{ + __path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready, + __path_kserve_model_infer, __path_kserve_model_metadata, + __path_kserve_model_metadata_ready, + }; + + #[derive(OpenApi)] + #[openapi( + paths( + kserve_model_infer, + kserve_health_live, + kserve_health_ready, + kerve_server_metadata, + kserve_model_metadata, + kserve_model_metadata_ready, + ), + components(schemas( + InferenceOutput, + InferenceRequest, + LiveResponse, + MetadataServerResponse, + OutputChunk, + ReadyResponse, + )) + )] + struct KServeApiDoc; + + doc.merge(KServeApiDoc::openapi()); + } // Configure Swagger UI let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); @@ -1798,6 +1833,27 @@ pub async fn run( } } + #[cfg(feature = "kserve")] + { + tracing::info!("Built with `kserve` feature"); + app = app + .route( + "/v2/models/:model_name/versions/:model_version/infer", + post(kserve_model_infer), + ) + .route( + "/v2/models/:model_name/versions/:model_version", + get(kserve_model_metadata), + ) + .route("/v2/health/ready", get(kserve_health_ready)) + .route("/v2/health/live", get(kserve_health_live)) + .route("/v2", get(kerve_server_metadata)) + .route( + "/v2/models/:model_name/versions/:model_version/ready", + get(kserve_model_metadata_ready), + ); + } + // add layers after routes app = app .layer(Extension(info)) From f1f28404e7c12873cc7f43b528141e9e60f2b89e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 14 Jun 2024 09:45:42 +0200 Subject: [PATCH 051/340] Add support for GPTQ Marlin (#2052) Add support for GPTQ Marlin kernels GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations: - bits: 4 or 8 - groupsize: -1, 32, 64, or 128 - desc_act: true/false Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format. The kernels were contributed by Neural Magic to VLLM. We vendor them here for convenience. --- .../test_flash_llama_gptq_marlin.json | 84 + ...st_flash_llama_gptq_marlin_all_params.json | 84 + .../test_flash_llama_gptq_marlin_load.json | 338 +++ .../models/test_flash_llama_gptq_marlin.py | 65 + server/Makefile | 1 - server/Makefile-marlin | 11 - server/marlin/COPYRIGHT | 20 + server/marlin/marlin_kernels/__init__.pyi | 44 + server/marlin/marlin_kernels/ext.cpp | 11 + server/marlin/marlin_kernels/ext.hh | 23 + server/marlin/marlin_kernels/gptq_marlin.cu | 1870 +++++++++++++++++ server/marlin/marlin_kernels/gptq_marlin.cuh | 76 + .../marlin_kernels/gptq_marlin_dtypes.cuh | 77 + .../marlin_kernels/gptq_marlin_repack.cu | 350 +++ .../marlin_kernels/marlin_cuda_kernel.cu | 1136 ++++++++++ server/marlin/marlin_kernels/py.typed | 0 server/marlin/setup.py | 21 + .../text_generation_server/layers/linear.py | 17 +- .../text_generation_server/layers/marlin.py | 256 ++- server/text_generation_server/models/bloom.py | 2 +- .../custom_modeling/flash_cohere_modeling.py | 2 +- .../flash_santacoder_modeling.py | 15 +- .../flash_starcoder2_modeling.py | 2 +- .../models/flash_cohere.py | 2 +- .../models/flash_dbrx.py | 2 +- .../models/flash_gemma.py | 2 +- .../models/flash_llama.py | 2 +- .../models/flash_mistral.py | 2 +- .../models/flash_neox.py | 2 +- .../models/flash_phi.py | 2 +- .../models/flash_qwen2.py | 2 +- .../text_generation_server/models/flash_rw.py | 2 +- .../models/flash_santacoder.py | 2 +- .../models/flash_starcoder2.py | 2 +- .../models/galactica.py | 2 +- .../text_generation_server/models/gpt_neox.py | 2 +- server/text_generation_server/models/mpt.py | 2 +- server/text_generation_server/models/opt.py | 2 +- .../text_generation_server/utils/weights.py | 253 ++- 39 files changed, 4651 insertions(+), 137 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json create mode 100644 integration-tests/models/test_flash_llama_gptq_marlin.py delete mode 100644 server/Makefile-marlin create mode 100644 server/marlin/COPYRIGHT create mode 100644 server/marlin/marlin_kernels/__init__.pyi create mode 100644 server/marlin/marlin_kernels/ext.cpp create mode 100644 server/marlin/marlin_kernels/ext.hh create mode 100644 server/marlin/marlin_kernels/gptq_marlin.cu create mode 100644 server/marlin/marlin_kernels/gptq_marlin.cuh create mode 100644 server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh create mode 100644 server/marlin/marlin_kernels/gptq_marlin_repack.cu create mode 100644 server/marlin/marlin_kernels/marlin_cuda_kernel.cu create mode 100644 server/marlin/marlin_kernels/py.typed create mode 100644 server/marlin/setup.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json new file mode 100644 index 00000000..0f99d259 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 262, + "logprob": -1.6230469, + "special": false, + "text": " " + }, + { + "id": 3270, + "logprob": -2.046875, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1425781, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.9238281, + "special": false, + "text": " request" + }, + { + "id": 13204, + "logprob": -0.076660156, + "special": false, + "text": ".method" + }, + { + "id": 624, + "logprob": -0.021987915, + "special": false, + "text": " ==" + }, + { + "id": 364, + "logprob": -0.39208984, + "special": false, + "text": " '" + }, + { + "id": 3019, + "logprob": -0.10821533, + "special": false, + "text": "POST" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n if request.method == 'POST" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json new file mode 100644 index 00000000..4152b5b3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.34375, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -2.2539062, + "special": false, + "text": "." + }, + { + "id": 578, + "logprob": -0.15563965, + "special": false, + "text": " The" + }, + { + "id": 3622, + "logprob": -0.8203125, + "special": false, + "text": " server" + }, + { + "id": 706, + "logprob": 0.0, + "special": false, + "text": " has" + }, + { + "id": 539, + "logprob": 0.0, + "special": false, + "text": " not" + }, + { + "id": 3686, + "logprob": 0.0, + "special": false, + "text": " yet" + }, + { + "id": 3288, + "logprob": 0.0, + "special": false, + "text": " sent" + }, + { + "id": 904, + "logprob": 0.0, + "special": false, + "text": " any" + }, + { + "id": 828, + "logprob": 0.0, + "special": false, + "text": " data" + }, + { + "id": 382, + "logprob": -1.5517578, + "special": false, + "text": ".\n\n" + } + ], + "top_tokens": null + }, + "generated_text": "Test request. The server has not yet sent any data.\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json new file mode 100644 index 00000000..75e90303 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 262, + "logprob": -1.6220703, + "special": false, + "text": " " + }, + { + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, + "special": false, + "text": " request" + }, + { + "id": 13204, + "logprob": -0.07672119, + "special": false, + "text": ".method" + }, + { + "id": 624, + "logprob": -0.021987915, + "special": false, + "text": " ==" + }, + { + "id": 364, + "logprob": -0.39208984, + "special": false, + "text": " '" + }, + { + "id": 3019, + "logprob": -0.10638428, + "special": false, + "text": "POST" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n if request.method == 'POST" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 262, + "logprob": -1.6220703, + "special": false, + "text": " " + }, + { + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, + "special": false, + "text": " request" + }, + { + "id": 13204, + "logprob": -0.07672119, + "special": false, + "text": ".method" + }, + { + "id": 624, + "logprob": -0.021987915, + "special": false, + "text": " ==" + }, + { + "id": 364, + "logprob": -0.39208984, + "special": false, + "text": " '" + }, + { + "id": 3019, + "logprob": -0.10638428, + "special": false, + "text": "POST" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n if request.method == 'POST" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 262, + "logprob": -1.6220703, + "special": false, + "text": " " + }, + { + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, + "special": false, + "text": " request" + }, + { + "id": 13204, + "logprob": -0.07672119, + "special": false, + "text": ".method" + }, + { + "id": 624, + "logprob": -0.021987915, + "special": false, + "text": " ==" + }, + { + "id": 364, + "logprob": -0.39208984, + "special": false, + "text": " '" + }, + { + "id": 3019, + "logprob": -0.10638428, + "special": false, + "text": "POST" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n if request.method == 'POST" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 262, + "logprob": -1.6220703, + "special": false, + "text": " " + }, + { + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, + "special": false, + "text": " request" + }, + { + "id": 13204, + "logprob": -0.07672119, + "special": false, + "text": ".method" + }, + { + "id": 624, + "logprob": -0.021987915, + "special": false, + "text": " ==" + }, + { + "id": 364, + "logprob": -0.39208984, + "special": false, + "text": " '" + }, + { + "id": 3019, + "logprob": -0.10638428, + "special": false, + "text": "POST" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n if request.method == 'POST" + } +] diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py new file mode 100644 index 00000000..9c37a644 --- /dev/null +++ b/integration-tests/models/test_flash_llama_gptq_marlin.py @@ -0,0 +1,65 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_gptq_marlin_handle(launcher): + with launcher( + "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): + await flash_llama_gptq_marlin_handle.health(300) + return flash_llama_gptq_marlin_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): + response = await flash_llama_gptq_marlin.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_marlin_all_params( + flash_llama_gptq_marlin, response_snapshot +): + response = await flash_llama_gptq_marlin.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_marlin_load( + flash_llama_gptq_marlin, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/Makefile b/server/Makefile index f2a45cc0..5257b876 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,7 +3,6 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-eetq -include Makefile-marlin include Makefile-selective-scan unit-tests: diff --git a/server/Makefile-marlin b/server/Makefile-marlin deleted file mode 100644 index 816546af..00000000 --- a/server/Makefile-marlin +++ /dev/null @@ -1,11 +0,0 @@ -marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c - -build-marlin: - if [ ! -d 'marlin' ]; then \ - pip install -U ninja packaging --no-cache-dir && \ - git clone https://github.com/IST-DASLab/marlin.git marlin; \ - fi - cd marlin && git fetch && git checkout $(marlin_commit) && python setup.py build - -install-marlin: build-marlin - cd marlin && git fetch && git checkout $(marlin_commit) && pip install -e . diff --git a/server/marlin/COPYRIGHT b/server/marlin/COPYRIGHT new file mode 100644 index 00000000..69f3b8e6 --- /dev/null +++ b/server/marlin/COPYRIGHT @@ -0,0 +1,20 @@ +These kernels were vendored from VLLM. The Marlin kernels were developed +by Elias Frantar and extended by Neural Magic. + +--- + +Copyright (C) Marlin.2024 Elias Frantar +Modified by Neural Magic +Copyright 2024 The vLLM team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi new file mode 100644 index 00000000..73597f0c --- /dev/null +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -0,0 +1,44 @@ +import torch + +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, +) -> torch.Tensor: + """ + Matrix multiplication using Marlin kernels. This is an extension of + `marlin_gemm` that supports converted GPTQ kernels. + """ + ... + +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + """Repack GPTQ parameters for Marlin kernels.""" + ... + +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + """ + Matrix multiplication using Marlin kernels. + """ + ... diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp new file mode 100644 index 00000000..5855714d --- /dev/null +++ b/server/marlin/marlin_kernels/ext.cpp @@ -0,0 +1,11 @@ +#include + +#include "ext.hh" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gptq_marlin_gemm", &gptq_marlin_gemm, + "Marlin gemm with GPTQ compatibility"); + m.def("gptq_marlin_repack", &gptq_marlin_repack, + "Repack GPTQ parameters for Marlin"); + m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); +} diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh new file mode 100644 index 00000000..9ea01a3f --- /dev/null +++ b/server/marlin/marlin_kernels/ext.hh @@ -0,0 +1,23 @@ +#pragma once + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &g_idx, + torch::Tensor &perm, torch::Tensor &workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu new file mode 100644 index 00000000..0beb9de1 --- /dev/null +++ b/server/marlin/marlin_kernels/gptq_marlin.cu @@ -0,0 +1,1870 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "gptq_marlin.cuh" +#include "gptq_marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace gptq_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) {} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +template +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { + res = __hmul2(res, s[0]); + } + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +template +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, + void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, int max_par) { + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); + } + + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by having + // a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || + (g_idx.size(0) == size_k && perm.size(0) == size_k), + "Unexpected g_idx.size(0) = ", g_idx.size(0), + " and perm.size(0) = ", perm.size(0), + ", where size_k = ", size_k); + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(0) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + num_groups = b_scales.size(0); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, + is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin.cuh b/server/marlin/marlin_kernels/gptq_marlin.cuh new file mode 100644 index 00000000..42af4495 --- /dev/null +++ b/server/marlin/marlin_kernels/gptq_marlin.cuh @@ -0,0 +1,76 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace gptq_marlin { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace gptq_marlin diff --git a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh b/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh new file mode 100644 index 00000000..ca1b7099 --- /dev/null +++ b/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh @@ -0,0 +1,77 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "gptq_marlin.cuh" +#include +#include + +namespace gptq_marlin { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace gptq_marlin + +#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin_repack.cu b/server/marlin/marlin_kernels/gptq_marlin_repack.cu new file mode 100644 index 00000000..4adc158e --- /dev/null +++ b/server/marlin/marlin_kernels/gptq_marlin_repack.cu @@ -0,0 +1,350 @@ +#include "gptq_marlin.cuh" + +namespace gptq_marlin { + +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + + #pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace gptq_marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); + TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", pack_factor = ", pack_factor); + TORCH_CHECK(b_q_weight.size(1) == size_n, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not size_n = ", size_n); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = + torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", has_perm = ", has_perm); + } + + return out; +} + +#endif diff --git a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu new file mode 100644 index 00000000..d124c014 --- /dev/null +++ b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu @@ -0,0 +1,1136 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 16 * thread_k_blocks / + 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 2 * ((threads / 32) / + (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + if (group_blocks == + -1) // for per-column quantization we finally apply the scale here + res = __hmul2(res, s[0]); + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +static constexpr int pack_factor_4bit = + 8; // We have 8 4-bit vals inside a 32 bit + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_n = -1, int sms = -1, int max_par = 16) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { + throw std::runtime_error( + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); + } + + // Uncomment for debug + // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + + // ", thread_n = " + str(th_config.thread_n) + + // ", num_threads = " + str(th_config.num_threads) + " for + // MKN = [" + str(prob_m) + + // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_m == 0 || prob_n == 0 || prob_k == 0) { + return; + } + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + int* locks = (int*)workspace; + + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + if (false) { + } + CALL_IF(8, 8, 256) + CALL_IF(16, 4, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace marlin + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k) { + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % marlin::tile_size == 0, + "size_k = " + str(size_k) + + " is not divisible by tile_size = " + str(marlin::tile_size)); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(marlin::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin::tile_size)); + + int actual_size_n = + (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize + if (b_scales.size(0) != 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + } + int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0); + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 128, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK( + size_n % marlin::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + str(marlin::min_thread_n)); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, + sms, marlin::max_par); + + return c; +} diff --git a/server/marlin/marlin_kernels/py.typed b/server/marlin/marlin_kernels/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/server/marlin/setup.py b/server/marlin/setup.py new file mode 100644 index 00000000..844e1139 --- /dev/null +++ b/server/marlin/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +extra_compile_args = [] + +setup( + name="marlin_kernels", + ext_modules=[ + CUDAExtension( + name="marlin_kernels", + sources=[ + "marlin_kernels/gptq_marlin.cu", + "marlin_kernels/gptq_marlin_repack.cu", + "marlin_kernels/marlin_cuda_kernel.cu", + "marlin_kernels/ext.cpp", + ], + extra_compile_args=extra_compile_args, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 3537b62d..d40b192f 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,6 +1,7 @@ from typing import Optional import torch from torch.nn import functional as F +from text_generation_server.layers.marlin import GPTQMarlinLinear from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "rocm": @@ -223,13 +224,23 @@ def get_linear(weight, bias, quantize): "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" ) elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinLinear, MarlinWeight + from text_generation_server.layers.marlin import ( + GPTQMarlinWeight, + MarlinLinear, + MarlinWeight, + ) - if not isinstance(weight, MarlinWeight): + if isinstance(weight, GPTQMarlinWeight): + linear = GPTQMarlinLinear( + weight=weight, + bias=bias, + ) + elif isinstance(weight, MarlinWeight): + linear = MarlinLinear(weight=weight, bias=bias) + else: raise NotImplementedError( f"The passed weight is not `marlin` compatible, loader needs to be updated." ) - linear = MarlinLinear(B=weight.B, s=weight.s, bias=bias) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a860d84b..4d4c635e 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,13 +1,15 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple, List import torch import torch.nn as nn +from text_generation_server.utils.import_utils import SYSTEM + try: - import marlin + import marlin_kernels except ImportError: - marlin = None + marlin_kernels = None try: major, _minor = torch.cuda.get_device_capability() @@ -15,9 +17,204 @@ try: except Exception: has_sm_8_0 = False + +GPTQ_MARLIN_BITS = [4, 8] +GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_TILE_SIZE = 16 +def _check_marlin_kernels(): + if not (SYSTEM == "cuda" and has_sm_8_0): + raise NotImplementedError( + "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." + ) + + if marlin_kernels is None: + raise NotImplementedError( + "marlin is not installed, install it with: pip install server/marlin" + ) + + +def _check_valid_shape(in_features: int, out_features: int): + if (in_features % 128 != 0 or out_features % 64 != 0) and ( + in_features % 64 != 0 or out_features % 128 != 0 + ): + raise ValueError( + f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." + " The shape elements must be divisible by (128, 64) or (64, 128)." + ) + + +# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 +def _get_perms() -> Tuple[List[int], List[int]]: + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +_scale_perm, _scale_perm_single = _get_perms() + + +def permute_scales(scales: torch.Tensor): + out_features = scales.shape[1] + if scales.shape[0] == 1: + scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + else: + scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm] + return scales.reshape((-1, out_features)).contiguous() + + +@dataclass +class GPTQMarlinWeight: + """ + Repacked GPTQ Marlin weights. + """ + + qweight: torch.Tensor + scales: torch.Tensor + g_idx: torch.Tensor + perm: torch.Tensor + bits: int + is_full_k: bool + + def __post_init__(self): + assert self.qweight.dtype == torch.int32 + assert self.scales.dtype == torch.float16 + assert self.g_idx.dtype == torch.int32 + assert self.perm.dtype == torch.int32 + + +def repack_gptq_for_marlin( + *, + qweight: torch.Tensor, + scales: torch.Tensor, + g_idx: torch.Tensor, + bits: int, + desc_act: bool, + groupsize: int, + sym: bool, + sharded_infeatures: bool, +) -> GPTQMarlinWeight: + """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" + _check_marlin_kernels() + assert marlin_kernels is not None + + if bits not in GPTQ_MARLIN_BITS: + supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) + raise RuntimeError( + f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" + ) + + if groupsize not in GPTQ_MARLIN_GROUP_SIZES: + supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) + raise RuntimeError( + f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" + ) + if not sym: + raise RuntimeError( + "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." + ) + + weights_per_int = 32 // bits + in_features = qweight.shape[0] * weights_per_int + out_features = qweight.shape[1] + + if in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisible by group size ({groupsize})" + ) + + if desc_act and groupsize != -1: + perm = torch.argsort(g_idx).to(torch.int) + g_idx = g_idx[perm] + else: + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) + + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, bits + ) + + scales = permute_scales(scales) + + is_full_k = not (desc_act and sharded_infeatures) + + return GPTQMarlinWeight( + qweight=repacked, + scales=scales, + g_idx=g_idx, + perm=perm, + bits=bits, + is_full_k=is_full_k, + ) + + +class GPTQMarlinLinear(nn.Module): + """ + Linear layer for GPTQ weights that were converted for the GPTQ-Marlin + kernels. + """ + + def __init__( + self, + *, + weight: GPTQMarlinWeight, + bias: Optional[torch.Tensor], + ): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE + out_features = weight.scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.bits = weight.bits + self.is_full_k = weight.is_full_k + + self.register_buffer("qweight", weight.qweight) + self.register_buffer("scales", weight.scales) + self.register_buffer("g_idx", weight.g_idx) + self.register_buffer("perm", weight.perm) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.gptq_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.g_idx, + self.perm, + self.workspace, + self.bits, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + self.is_full_k, + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + @dataclass class MarlinWeight: """ @@ -31,28 +228,20 @@ class MarlinWeight: B: torch.Tensor s: torch.Tensor + def __post_init__(self): + assert self.B.dtype == torch.int32 + assert self.s.dtype == torch.float16 + class MarlinLinear(nn.Module): - def __init__( - self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor] - ): + def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): super().__init__() - if not has_sm_8_0: - raise NotImplementedError( - "Using quantized marlin models requires CUDA capability 8.0 or later" - ) + _check_marlin_kernels() + assert marlin_kernels is not None - if marlin is None: - raise NotImplementedError( - "You do not seem to have marlin installed, either install it (cd server && make install-marlin)" - ) - - assert B.dtype == torch.int32 - assert s.dtype == torch.float16 - - in_features = B.shape[0] * MARLIN_TILE_SIZE - out_features = s.shape[1] + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE + out_features = weight.s.shape[1] assert ( in_features % 128 == 0 ), f"Number of input features ({in_features}) not divisable by 128" @@ -60,35 +249,36 @@ class MarlinLinear(nn.Module): out_features % 256 == 0 ), f"Number of output features ({out_features}) not divisable by 256" - group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0] - assert group_size in { + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + assert groupsize in { -1, 128, - }, f"Group size must be -1 or 128, was {group_size}" + }, f"Group size must be -1 or 128, was {groupsize}" - self.register_buffer("B", B) - self.register_buffer("s", s) + self.register_buffer("B", weight.B) + self.register_buffer("s", weight.s) if bias is not None: self.register_buffer("bias", bias) else: self.bias = None self.workspace = torch.zeros( - out_features // 128 * 16, dtype=torch.int, device=B.device + out_features // 64 * 16, dtype=torch.int, device=weight.B.device ) def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin is not None - C = torch.empty( - A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device - ) - marlin.mul( - A.view((-1, A.shape[-1])), + assert marlin_kernels is not None + + C = marlin_kernels.marlin_gemm( + A.view(-1, A.shape[-1]), self.B, - C.view((-1, C.shape[-1])), self.s, self.workspace, + A.shape[0], + self.s.shape[1], + A.shape[1], ) + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) if self.bias is not None: C += self.bias diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 65c9f317..38006502 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -83,7 +83,7 @@ class BLOOMSharded(CausalLM): process_group=self.process_group, prefix="transformer", ) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = BloomForCausalLM(config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 764dc6e2..6d315ba5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -166,7 +166,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if config.quantize not in ["gptq", "awq", "marlin"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 4fa6516e..2ae0908c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -81,16 +81,11 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - ( - bits, - groupsize, - _, - quant_method, - ) = weights._get_gptq_params() - if quant_method == "gptq": + gptq_params = weights._get_gptq_params() + if gptq_params.quant_method == "gptq": g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - elif quant_method == "awq": + elif gptq_params.quant_method == "awq": g_idx = None from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, @@ -105,8 +100,8 @@ def _load_multi_mqa_gptq( qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=bits, - groupsize=groupsize, + bits=gptq_params.bits, + groupsize=gptq_params.groupsize, use_exllama=HAS_EXLLAMA, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 37486e9d..c3e2e099 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -130,7 +130,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if config.quantize not in ["gptq", "awq", "marlin"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index b907ee08..1077d78e 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -55,7 +55,7 @@ class FlashCohere(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashCohereForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py index d5eb1a6e..ffb6d5a6 100644 --- a/server/text_generation_server/models/flash_dbrx.py +++ b/server/text_generation_server/models/flash_dbrx.py @@ -80,7 +80,7 @@ class FlashDbrx(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashDbrxForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 358883e6..1b7b2772 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -53,7 +53,7 @@ class FlashGemma(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) # TODO hardcoded diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index c5cbd2b8..e27f0da2 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "exl2"]: + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) prefix = "" diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 081c2e2c..0fdda6d2 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -68,7 +68,7 @@ class BaseFlashMistral(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) prefix = "" diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index adefaeb2..d3871c2f 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -58,7 +58,7 @@ class FlashNeoXSharded(FlashCausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashGPTNeoXForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 6a52c1d7..0cc67cec 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -53,7 +53,7 @@ class FlashPhi(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashPhiForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 75285863..9fcfce9d 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -62,7 +62,7 @@ class FlashQwen2(BaseFlashMistral): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) model = Qwen2ForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index e6350611..187f26a8 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -67,7 +67,7 @@ class FlashRWSharded(FlashCausalLM): config.quantize = quantize config.speculator = speculator - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashRWForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 2ad36b93..a8d84fca 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -69,7 +69,7 @@ class FlashSantacoderSharded(FlashCausalLM): process_group=self.process_group, aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashSantacoderForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 5533c9d9..1ac731be 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -61,7 +61,7 @@ class FlashStarcoder2(BaseFlashMistral): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "marlin"]: weights._set_gptq_params(model_id, revision) model = FlashStarcoder2ForCausalLM(config, weights) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index d0f2b915..f39bd1e9 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -205,7 +205,7 @@ class GalacticaSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = OPTForCausalLM(config, weights) diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index d1f8f5be..8d2cb0e1 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -58,7 +58,7 @@ class GPTNeoxSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = GPTNeoxForCausalLM(config, weights) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index 8d8b4909..65180e73 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -82,7 +82,7 @@ class MPTSharded(CausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) config.quantize = quantize diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 87319ef0..1f4fbfcd 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -56,7 +56,7 @@ class OPTSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) - if config.quantize == "gptq": + if config.quantize in ["gptq", "marlin"]: weights._set_gptq_params(model_id, revision) model = OPTForCausalLM(config, weights) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 4d5fcb25..45cfc073 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,4 +1,5 @@ import os +from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open, SafetensorError @@ -9,6 +10,15 @@ import json from text_generation_server.utils.log import log_once +@dataclass +class _GPTQParams: + bits: int + groupsize: int + desc_act: bool + quant_method: str + sym: bool + + class Weights: def __init__( self, @@ -181,15 +191,15 @@ class Weights: f"Cannot load `{quantize}` weight, make sure the model is already quantized." ) - bits, groupsize, _, quant_method = self._get_gptq_params() + gptq_params = self._get_gptq_params() qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes) scales = self._get_qweight(f"{prefix}.scales", block_sizes) scales = scales.to(dtype=self.dtype) - if quantize == "gptq" and quant_method == "gptq": + if quantize == "gptq" and gptq_params.quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") - elif quantize == "gptq" and quant_method == "awq": + elif quantize == "gptq" and gptq_params.quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) @@ -199,8 +209,11 @@ class Weights: qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) g_idx = ( - torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) - // groupsize + torch.arange( + qweight.shape[0] * (32 // gptq_params.bits), + device=qweight.device, + ) + // gptq_params.groupsize ).to(dtype=torch.int32) else: g_idx = None @@ -210,16 +223,43 @@ class Weights: qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=bits, - groupsize=groupsize, + bits=gptq_params.bits, + groupsize=gptq_params.groupsize, use_exllama=False, ) elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeight + from text_generation_server.layers.marlin import ( + MarlinWeight, + repack_gptq_for_marlin, + ) - B = self._get_qweight(f"{prefix}.B", block_sizes) - s = self._get_qweight(f"{prefix}.s", block_sizes) - weight = MarlinWeight(B=B, s=s) + quant_method = getattr(self, "quant_method", "marlin") + if quant_method == "gptq": + gptq_params = self._get_gptq_params() + try: + qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + scales = self._get_qweight(f"{prefix}.scales", block_sizes) + g_idx = self.get_tensor(f"{prefix}.g_idx") + weight = repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) + + else: + B = self._get_qweight(f"{prefix}.B", block_sizes) + s = self._get_qweight(f"{prefix}.s", block_sizes) + weight = MarlinWeight(B=B, s=s) else: slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] @@ -295,20 +335,23 @@ class Weights: [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) - bits, groupsize, desc_act, quant_method = self._get_gptq_params() + gptq_params = self._get_gptq_params() from text_generation_server.layers.gptq import HAS_EXLLAMA use_exllama = ( - bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act + gptq_params.bits == 4 + and HAS_EXLLAMA + and quantize == "gptq" + and not gptq_params.desc_act ) - if quantize == "gptq" and quant_method == "gptq": + if quantize == "gptq" and gptq_params.quant_method == "gptq": w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - elif quantize == "gptq" and quant_method == "awq": + elif quantize == "gptq" and gptq_params.quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) @@ -322,9 +365,10 @@ class Weights: else: g_idx = ( torch.arange( - qweight.shape[0] * (32 // bits), device=qweight.device + qweight.shape[0] * (32 // gptq_params.bits), + device=qweight.device, ) - // groupsize + // gptq_params.groupsize ).to(dtype=torch.int32) else: g_idx = None @@ -334,24 +378,62 @@ class Weights: qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=bits, - groupsize=groupsize, + bits=gptq_params.bits, + groupsize=gptq_params.groupsize, use_exllama=use_exllama, ) elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeight + from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + MarlinWeight, + repack_gptq_for_marlin, + ) - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1) + quant_method = getattr(self, "quant_method", "marlin") + if quant_method == "gptq": + gptq_params = self._get_gptq_params() + try: + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], + dim=1, + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) - weight = MarlinWeight(B=B, s=s) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + weight = repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) + else: + try: + B = torch.cat( + [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight, make sure the model is already quantized" + ) + s = torch.cat( + [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -401,12 +483,12 @@ class Weights: elif quantize == "gptq": use_exllama = True - bits, groupsize, desc_act, quant_method = self._get_gptq_params() + gptq_params = self._get_gptq_params() - if bits != 4: + if gptq_params.bits != 4: use_exllama = False - if desc_act: + if gptq_params.desc_act: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False @@ -417,9 +499,9 @@ class Weights: "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - if quant_method == "gptq": + if gptq_params.quant_method == "gptq": g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - elif quant_method == "awq": + elif gptq_params.quant_method == "awq": g_idx = None if self.process_group.size() > 1: @@ -428,7 +510,10 @@ class Weights: not torch.equal( g_idx.cpu(), torch.tensor( - [i // groupsize for i in range(g_idx.shape[0])], + [ + i // gptq_params.groupsize + for i in range(g_idx.shape[0]) + ], dtype=torch.int32, ), ) @@ -455,7 +540,7 @@ class Weights: else: log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - if use_exllama and groupsize != -1: + if use_exllama and gptq_params.groupsize != -1: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0) else: @@ -465,7 +550,7 @@ class Weights: if use_exllama and g_idx is not None: g_idx = g_idx - g_idx[0] - if quant_method == "awq": + if gptq_params.quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) @@ -479,9 +564,10 @@ class Weights: else: g_idx = ( torch.arange( - qweight.shape[0] * (32 // bits), device=qweight.device + qweight.shape[0] * (32 // gptq_params.bits), + device=qweight.device, ) - // groupsize + // gptq_params.groupsize ).to(dtype=torch.int32) weight = GPTQWeight( @@ -489,14 +575,14 @@ class Weights: qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=bits, - groupsize=groupsize, + bits=gptq_params.bits, + groupsize=gptq_params.groupsize, use_exllama=use_exllama, ) elif quantize == "awq": from text_generation_server.layers.gptq import GPTQWeight - bits, groupsize, _, _ = self._get_gptq_params() + gptq_params = self._get_gptq_params() try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) @@ -515,38 +601,74 @@ class Weights: qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=bits, - groupsize=groupsize, + bits=gptq_params.bits, + groupsize=gptq_params.groupsize, use_exllama=use_exllama, ) elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeight + from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + MarlinWeight, + repack_gptq_for_marlin, + ) - try: - B = self.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + quant_method = getattr(self, "quant_method", "marlin") + if quant_method == "gptq": + log_once(logger.info, "Converting GPTQ model to Marlin packing format.") + gptq_params = self._get_gptq_params() + + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + if gptq_params.desc_act or gptq_params.groupsize == -1: + scales = self.get_tensor(f"{prefix}.scales") + else: + scales = self.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = self.process_group.size() > 1 + + weight = repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=sharded_in_features, ) - - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when group_size == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") else: - s = self.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) + try: + B = self.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = self.get_tensor(f"{prefix}.s") + else: + s = self.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> Tuple[int, int, int, str]: + def _get_gptq_params(self) -> _GPTQParams: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() desc_act = False + sym = True quant_method = "gptq" except (SafetensorError, RuntimeError) as e: try: @@ -554,10 +676,17 @@ class Weights: groupsize = self.gptq_groupsize desc_act = getattr(self, "gptq_desc_act", False) quant_method = getattr(self, "quant_method", "gptq") + sym = getattr(self, "sym", True) except Exception: raise e - return bits, groupsize, desc_act, quant_method + return _GPTQParams( + bits=bits, + desc_act=desc_act, + groupsize=groupsize, + quant_method=quant_method, + sym=sym, + ) def _set_gptq_params(self, model_id, revision): filename = "config.json" @@ -574,6 +703,7 @@ class Weights: self.gptq_groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models self.quant_method = data["quantization_config"]["quant_method"] + self.gptq_sym = data["quantization_config"]["sym"] self.gptq_desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" @@ -588,6 +718,7 @@ class Weights: data = json.load(f) self.gptq_bits = data["bits"] self.gptq_groupsize = data["group_size"] + self.gptq_sym = data["sym"] self.gptq_desc_act = data["desc_act"] if "version" in data and data["version"] == "GEMM": self.quant_method = "awq" From b07a2518d9c2ecdc7b3ecc840bf095ae9dcaafbb Mon Sep 17 00:00:00 2001 From: Tiezhen WANG <38108242+xianbaoqian@users.noreply.github.com> Date: Fri, 14 Jun 2024 17:59:33 +0800 Subject: [PATCH 052/340] Update the link for qwen2 (#2068) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update the link for qwen2 * Fix Qwen2 model URL in model table * Fix too eager staging --------- Co-authored-by: Daniël de Kok --- docs/source/supported_models.md | 2 +- server/text_generation_server/models/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 4b6cf731..3468e988 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -20,7 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) - [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) - [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) -- [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) +- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f) - [Opt](https://huggingface.co/facebook/opt-6.7b) - [T5](https://huggingface.co/google/flan-t5-xxl) - [Galactica](https://huggingface.co/facebook/galactica-120b) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a61cb83b..76dca3dc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -196,7 +196,7 @@ class ModelType(enum.Enum): QWEN2 = { "type": "qwen2", "name": "Qwen 2", - "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", } OPT = { "type": "opt", From 8ee52e91f44f5d1863bda8cd39807c3f954a3215 Mon Sep 17 00:00:00 2001 From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:28:34 +0200 Subject: [PATCH 053/340] Adding architecture document (#2044) * doc: adding architecture document * doc: add architecture to toctree * fix: avoid cargo lock changes * fix: avoid cargo lock tweak --------- Co-authored-by: drbh --- docs/source/_toctree.yml | 2 + docs/source/architecture.md | 227 ++++++++++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100644 docs/source/architecture.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a7351a33..7599562a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -17,6 +17,8 @@ title: Supported Models and Hardware - local: messages_api title: Messages API + - local: architecture + title: Internal Architecture title: Getting started - sections: - local: basic_tutorials/consuming_tgi diff --git a/docs/source/architecture.md b/docs/source/architecture.md new file mode 100644 index 00000000..b7885879 --- /dev/null +++ b/docs/source/architecture.md @@ -0,0 +1,227 @@ +# Text Generation Inference Architecture + +This document aims at describing the architecture of Text Generation Inference (TGI), by describing the call flow between the separate components. + +A high-level architecture diagram can be seen here: + +![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) + +This diagram shows well there are these separate components: + +- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server. +- **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent. +- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments. + +The router and the model server can be two different machines, they do not need to be deployed together. + +## The Router + +This component is a rust web server binary that accepts HTTP requests using the custom [HTTP API](https://huggingface.github.io/text-generation-inference/), as well as OpenAI's [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api). +The router receives the API calls and handles the "baches" logic (and introduction to batching can be found [here](https://github.com/huggingface/text-generation-inference/blob/main/router/README.md)). +It uses different strategies to reduce latency between requests and responses, especially oriented to decoding latency. It will use queues, schedulers, and block allocators to achieve that and produce batched requests that it will then be sent to the model server. + +### Router's command line + +The router command line will be the way to pass parameters to it (it does not rely on configuration file): + +``` +Text Generation Webserver + +Usage: text-generation-router [OPTIONS] + +Options: + --max-concurrent-requests + [env: MAX_CONCURRENT_REQUESTS=] [default: 128] + --max-best-of + [env: MAX_BEST_OF=] [default: 2] + --max-stop-sequences + [env: MAX_STOP_SEQUENCES=] [default: 4] + --max-top-n-tokens + [env: MAX_TOP_N_TOKENS=] [default: 5] + --max-input-tokens + [env: MAX_INPUT_TOKENS=] [default: 1024] + --max-total-tokens + [env: MAX_TOTAL_TOKENS=] [default: 2048] + --waiting-served-ratio + [env: WAITING_SERVED_RATIO=] [default: 1.2] + --max-batch-prefill-tokens + [env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096] + --max-batch-total-tokens + [env: MAX_BATCH_TOTAL_TOKENS=] + --max-waiting-tokens + [env: MAX_WAITING_TOKENS=] [default: 20] + --max-batch-size + [env: MAX_BATCH_SIZE=] + --hostname + [env: HOSTNAME=] [default: 0.0.0.0] + -p, --port + [env: PORT=] [default: 3000] + --master-shard-uds-path + [env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0] + --tokenizer-name + [env: TOKENIZER_NAME=] [default: bigscience/bloom] + --tokenizer-config-path + [env: TOKENIZER_CONFIG_PATH=] + --revision + [env: REVISION=] + --validation-workers + [env: VALIDATION_WORKERS=] [default: 2] + --json-output + [env: JSON_OUTPUT=] + --otlp-endpoint + [env: OTLP_ENDPOINT=] + --cors-allow-origin + [env: CORS_ALLOW_ORIGIN=] + --ngrok + [env: NGROK=] + --ngrok-authtoken + [env: NGROK_AUTHTOKEN=] + --ngrok-edge + [env: NGROK_EDGE=] + --messages-api-enabled + [env: MESSAGES_API_ENABLED=] + --disable-grammar-support + [env: DISABLE_GRAMMAR_SUPPORT=] + --max-client-batch-size + [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] + -h, --help + Print help + -V, --version + Print version +``` + +## The Model Server + +The model server is a python server, capable of starting a server waiting for gRPC requests, loads a given model, perform sharding to provide [tensor parallelism](https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism), and stays alive while waiting for new requests. +The model server supports models instantiated using Pytorch and optimized for inference mainly on CUDA/ROCM. + +### Model Server Variants + +Several variants of the model server exist that are actively supported by Hugging Face: + +- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). +- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi). +- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). +- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). + +Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations. + +### Command Line Interface + +The official command line interface (CLI) for the server supports three subcommands, `download-weights`, `quantize` and `serve`: + +- `download-weights` will download weights from the hub and, in some variants it will convert weights to a format that is adapted to the given implementation; +- `quantize` will allow to quantize a model using the `qptq` package. This feature is not available nor supported on all variants; +- `serve` will start the server that load a model (or a model shard), receives gRPC calls from the router, performs an inference and provides a formatted response to the given request. + +Serve's command line parameters on the TGI repository are these: + +``` + Usage: cli.py serve [OPTIONS] MODEL_ID + +╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * model_id TEXT [default: None] [required] │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --revision TEXT [default: None] │ +│ --sharded --no-sharded [default: no-sharded] │ +│ --quantize [bitsandbytes|bitsandbytes [default: None] │ +│ -nf4|bitsandbytes-fp4|gptq │ +│ |awq|eetq|exl2|fp8] │ +│ --speculate INTEGER [default: None] │ +│ --dtype [float16|bfloat16] [default: None] │ +│ --trust-remote-code --no-trust-remote-code [default: │ +│ no-trust-remote-code] │ +│ --uds-path PATH [default: │ +│ /tmp/text-generation-serve… │ +│ --logger-level TEXT [default: INFO] │ +│ --json-output --no-json-output [default: no-json-output] │ +│ --otlp-endpoint TEXT [default: None] │ +│ --help Show this message and exit. │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +``` + +Note that some variants might support different parameters, and they could possibly accept more options that can be passed on using environment variables. + +## Call Flow + +Once both components are initialized, weights downloaded and model server is up and running, router and model server exchange data and info through the gRPC call. There are currently two supported schemas, [v2](https://github.com/huggingface/text-generation-inference/blob/main/proto/generate.proto) and [v3](https://github.com/huggingface/text-generation-inference/blob/main/proto/v3/generate.proto). These two versions are almost identical, except for: + +- input chunks support, for text and image data, +- paged attention support + +Here's a diagram that displays the exchanges that follow the router and model server startup. + +```mermaid +sequenceDiagram + + Router->>Model Server: service discovery + Model Server-->>Router: urls for other shards + + Router->>Model Server: get model info + Model Server-->>Router: shard info + + Router->>Model Server: health check + Model Server-->>Router: health OK + + Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size) + Model Server-->>Router: warmup result +``` + +After these are done, the router is ready to receive generate calls from multiple clients. Here's an example. + +```mermaid +sequenceDiagram + participant Client 1 + participant Client 2 + participant Client 3 + participant Router + participant Model Server + + Client 1->>Router: generate_stream + Router->>Model Server: prefill(batch1) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 1 + + Router->>Model Server: decode(cached_batch1) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 2 + + Router->>Model Server: decode(cached_batch1) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 3 + + Client 2->>Router: generate_stream + Router->>Model Server: prefill(batch2) + Note right of Model Server: This stops previous batch, that is restarted + Model Server-->>Router: generations, cached_batch2, timings + Router-->>Client 2: token 1' + + Router->>Model Server: decode(cached_batch1, cached_batch2) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 4 + Router-->>Client 2: token 2' + + Note left of Client 1: Client 1 leaves + Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2) + Model Server-->>Router: filtered batch + + Router->>Model Server: decode(cached_batch2) + Model Server-->>Router: generations, cached_batch2, timings + Router-->>Client 2: token 3' + + Client 3->>Router: generate_stream + Note right of Model Server: This stops previous batch, that is restarted + Router->>Model Server: prefill(batch3) + Note left of Client 1: Client 3 leaves without receiving any batch + Router->>Model Server: clear_cache(batch3) + Note right of Model Server: This stops previous batch, that is restarted + + Router->>Model Server: decode(cached_batch3) + Note right of Model Server: Last token (stopping criteria) + Model Server-->>Router: generations, cached_batch3, timings + Router-->>Client 2: token 4' + + +``` From fb939370a3fb069ac17e1917d64124bead9f5aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 17 Jun 2024 10:49:41 +0200 Subject: [PATCH 054/340] Support different image sizes in prefill in VLMs (#2065) When a batch contained images if different sizes during prefill, the server would fail (see e.g. #2056). Images were processed separately and then concatenated. However, this can fail for images with different sizes. Fix this by preprocessing all images in the batch together, so that the image processor can ensure that all image tensors have compatible sizes. --- .../test_flash_pali_gemma_two_images.json | 61 ++++++++ .../test_idefics/test_idefics_two_images.json | 85 +++++++++++ .../test_flash_idefics2_two_images.json | 133 ++++++++++++++++++ .../models/test_flash_pali_gemma.py | 23 +++ integration-tests/models/test_idefics.py | 21 +++ integration-tests/models/test_idefics2.py | 23 +++ .../models/vlm_causal_lm.py | 57 ++++---- 7 files changed, 376 insertions(+), 27 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json create mode 100644 integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json create mode 100644 integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json new file mode 100644 index 00000000..ab4f3015 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json @@ -0,0 +1,61 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 8, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 2502, + "logprob": -1.734375, + "special": false, + "text": "image" + }, + { + "id": 2196, + "logprob": -0.5756836, + "special": false, + "text": " result" + }, + { + "id": 604, + "logprob": -0.007843018, + "special": false, + "text": " for" + }, + { + "id": 12254, + "logprob": -1.7167969, + "special": false, + "text": " chicken" + }, + { + "id": 611, + "logprob": -0.17053223, + "special": false, + "text": " on" + }, + { + "id": 573, + "logprob": -0.7626953, + "special": false, + "text": " the" + }, + { + "id": 8318, + "logprob": -0.02709961, + "special": false, + "text": " beach" + }, + { + "id": 1, + "logprob": -0.20739746, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "image result for chicken on the beach" +} diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json new file mode 100644 index 00000000..a4727707 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json @@ -0,0 +1,85 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 12, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 450, + "logprob": -0.26342773, + "special": false, + "text": " The" + }, + { + "id": 21282, + "logprob": -0.01838684, + "special": false, + "text": " cow" + }, + { + "id": 322, + "logprob": -0.18041992, + "special": false, + "text": " and" + }, + { + "id": 521, + "logprob": -0.62841797, + "special": false, + "text": " ch" + }, + { + "id": 21475, + "logprob": -0.0037956238, + "special": false, + "text": "icken" + }, + { + "id": 526, + "logprob": -0.018737793, + "special": false, + "text": " are" + }, + { + "id": 373, + "logprob": -1.0820312, + "special": false, + "text": " on" + }, + { + "id": 263, + "logprob": -0.5083008, + "special": false, + "text": " a" + }, + { + "id": 25695, + "logprob": -0.07128906, + "special": false, + "text": " beach" + }, + { + "id": 29889, + "logprob": -0.12573242, + "special": false, + "text": "." + }, + { + "id": 32002, + "logprob": -0.0029792786, + "special": true, + "text": "" + }, + { + "id": 2, + "logprob": -0.00024962425, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " The cow and chicken are on a beach." +} diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json new file mode 100644 index 00000000..86c95b29 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -0,0 +1,133 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 415, + "logprob": -0.04421997, + "special": false, + "text": " The" + }, + { + "id": 12072, + "logprob": -0.13500977, + "special": false, + "text": " cow" + }, + { + "id": 349, + "logprob": -0.06750488, + "special": false, + "text": " is" + }, + { + "id": 6328, + "logprob": -0.6352539, + "special": false, + "text": " standing" + }, + { + "id": 356, + "logprob": -0.16186523, + "special": false, + "text": " on" + }, + { + "id": 272, + "logprob": -0.5078125, + "special": false, + "text": " the" + }, + { + "id": 10305, + "logprob": -0.017913818, + "special": false, + "text": " beach" + }, + { + "id": 304, + "logprob": -1.5205078, + "special": false, + "text": " and" + }, + { + "id": 272, + "logprob": -0.029174805, + "special": false, + "text": " the" + }, + { + "id": 13088, + "logprob": -0.003479004, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.0035095215, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.3088379, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.027755737, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.31884766, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.047943115, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.0002925396, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.02935791, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.031219482, + "special": false, + "text": "." + }, + { + "id": 32002, + "logprob": -0.00034475327, + "special": true, + "text": "" + }, + { + "id": 2, + "logprob": -1.1920929e-07, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money." +} diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index d4e83c9f..6be1750c 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle): return flash_pali_gemma_handle.client +def get_chicken(): + with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + def get_cow_beach(): with open("integration-tests/images/cow_beach.png", "rb") as image_file: encoded_string = base64.b64encode(image_file.read()) @@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): assert response.generated_text == "beach" assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): + chicken = get_chicken() + cow_beach = get_cow_beach() + response = await flash_pali_gemma.generate( + f"caption![]({chicken})![]({cow_beach})\n", + max_new_tokens=20, + ) + # Is PaliGemma not able to handle two separate images? At least we + # get output showing that both images are used. + assert ( + response.generated_text == "image result for chicken on the beach" + ), f"{repr(response.generated_text)}" + assert response == response_snapshot diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index aeeaffa1..ac807b76 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -23,6 +23,12 @@ def get_chicken(): return f"data:image/png;base64,{encoded_string.decode('utf-8')}" +def get_cow_beach(): + with open("integration-tests/images/cow_beach.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + @pytest.mark.asyncio async def test_idefics(idefics, response_snapshot): chicken = get_chicken() @@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot): assert response == response_snapshot +@pytest.mark.asyncio +@pytest.mark.private +async def test_idefics_two_images(idefics, response_snapshot): + chicken = get_chicken() + cow_beach = get_cow_beach() + response = await idefics.generate( + f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", + max_new_tokens=20, + ) + assert ( + response.generated_text == " The cow and chicken are on a beach." + ), f"{repr(response.generated_text)}" + assert response == response_snapshot + + @pytest.mark.asyncio async def test_idefics_load(idefics, generate_load, response_snapshot): chicken = get_chicken() diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index d34cce34..9aaf6d8a 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -9,6 +9,12 @@ def get_chicken(): return f"data:image/png;base64,{encoded_string.decode('utf-8')}" +def get_cow_beach(): + with open("integration-tests/images/cow_beach.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + @pytest.fixture(scope="module") def flash_idefics2_next_handle(launcher): with launcher( @@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot assert response == response_snapshot +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot): + chicken = get_chicken() + cow_beach = get_cow_beach() + response = await flash_idefics2_next.generate( + f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", + max_new_tokens=20, + ) + assert ( + response.generated_text + == " The cow is standing on the beach and the chicken is sitting on a pile of money." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 20 + assert response == response_snapshot + + @pytest.mark.asyncio @pytest.mark.private async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot): diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 59a6fab1..8b5819d1 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -53,7 +53,9 @@ def image_text_replacement(image_input, config, image_id) -> str: num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info(f"Found {num_features} in image of resolution {height}x{width}") + logger.info( + f"Found {num_features} features in image of resolution {height}x{width}" + ) return "" * num_features elif config.model_type == "paligemma": @@ -133,23 +135,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch): def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config ): + # Process images first. We need all of them so that the processor + # can make the image splits the same size. And we need the final + # sizes to insert correct number of image tokens. + images = [] + for r in requests: + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + pass + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + if config.model_type == "llava_next": + images.append(image) + else: + images.append([image]) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + if images: + image_inputs = processor.image_processor(images, return_tensors="pt") + else: + image_inputs = None + batch_inputs = [] - image_inputs = [] max_truncation = 0 + image_id = 0 for r in requests: full_text = "" - image_id = 0 for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": full_text += chunk.text elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - image_input = processor.image_processor(image, return_tensors="pt") - full_text += image_text_replacement(image_input, config, image_id) - image_inputs.append(image_input) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") + full_text += image_text_replacement(image_inputs, config, image_id) + image_id += 1 batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) @@ -160,24 +180,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): max_length=max_truncation, add_special_tokens=not config.model_type == "paligemma", )["input_ids"] - if image_inputs: - image_input = image_inputs[0] - new_image_inputs = { - "pixel_values": torch.cat( - [img["pixel_values"] for img in image_inputs], dim=0 - ), - } - if "pixel_attention_mask" in image_input: - new_image_inputs["pixel_attention_mask"] = torch.cat( - [img["pixel_attention_mask"] for img in image_inputs], dim=0 - ) - if "image_sizes" in image_input: - new_image_inputs["image_sizes"] = torch.cat( - [img["image_sizes"] for img in image_inputs], dim=0 - ) - image_inputs = new_image_inputs - else: - image_inputs = None + return batch_tokenized_inputs, image_inputs @classmethod From 58c743bc908a93f3f32352453a45a2cc3493711a Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 17 Jun 2024 12:09:31 +0200 Subject: [PATCH 055/340] Contributing guide & Code of Conduct (#2074) * Contributing guide & Code of Conduct * Redirect to GitHub's tutorial on PRs --- CODE_OF_CONDUCT.md | 133 +++++++++++++++++++++++++++++++++++++++++++++ CONTRIBUTING.md | 120 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..ef09fa13 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +feedback@huggingface.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..39b57c19 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,120 @@ + + +# Contribute to text-generation-inference + +Everyone is welcome to contribute, and we value everybody's contribution. Code +contributions are not the only way to help the community. Answering questions, helping +others, and improving the documentation are also immensely valuable. + +It also helps us if you spread the word! Reference the library in blog posts +about the awesome projects it made possible, shout out on Twitter every time it has +helped you, or simply ⭐️ the repository to say thank you. + +However you choose to contribute, please be mindful and respect our +[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md). + +**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).** + +## Ways to contribute + +There are several ways you can contribute to text-generation-inference. + +* Fix outstanding issues with the existing code. +* Submit issues related to bugs or desired new features. +* Contribute to the examples or to the documentation. + +> All contributions are equally valuable to the community. 🥰 + +## Fixing outstanding issues + +If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open +a Pull Request! + +## Submitting a bug-related issue or feature request + +Do your best to follow these guidelines when submitting a bug-related issue or a feature +request. It will make it easier for us to come back to you quickly and with good +feedback. + +### Did you find a bug? + +The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter. + +Before you report an issue, we would really appreciate it if you could **make sure the bug was not +already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the +library itself, and not your code. + +Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so +we can quickly resolve it: + +* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies). +* A short, self-contained, code snippet that allows us to reproduce the bug. +* The *full* traceback if an exception is raised. +* Attach any other additional information, like screenshots, you think may help. + +To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag: + +```bash +text-generation-launcher --env +``` + +This will precede the launch of the model with the information relative to your environment. We recommend pasting +that in your issue report. + +### Do you want a new feature? + +If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe: + +1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it + a feature related to something you need for a project? Is it something you worked on and think it could benefit + the community? + + Whatever it is, we'd love to hear about it! + +2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better + we'll be able to help you. +3. Provide a *code snippet* that demonstrates the feature's usage. +4. If the feature is related to a paper, please include a link. + +If your issue is well written we're already 80% of the way there by the time you create it. + +We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE) +to help you get started with your issue. + +## Do you want to implement a new model? + +New models are constantly released and if you want to implement a new model, please provide the following information: + +* A short description of the model and a link to the paper. +* Link to the implementation if it is open-sourced. +* Link to the model weights if they are available. + +If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference! + +## Do you want to add documentation? + +We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know +how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be +happy to make the changes or help you make a contribution if you're interested! + +## I want to become a maintainer of the project. How do I get there? + +TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have +motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference +service. + +If you are such an individual (or organization), please reach out to us and let's collaborate. \ No newline at end of file From b3dadbde066eaafd21d19508031d25a53f83654f Mon Sep 17 00:00:00 2001 From: Ziru Niu Date: Mon, 17 Jun 2024 18:10:01 +0800 Subject: [PATCH 056/340] fix build.rs watch files (#2072) --- router/client/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/client/build.rs b/router/client/build.rs index a7ade9b0..210cd603 100644 --- a/router/client/build.rs +++ b/router/client/build.rs @@ -1,7 +1,7 @@ use std::fs; fn main() -> Result<(), Box> { - println!("cargo:rerun-if-changed=../../proto/**"); + println!("cargo:rerun-if-changed=../../proto/"); fs::create_dir_all("src/v2/pb").unwrap_or(()); let mut config = prost_build::Config::new(); From 6b2cbd016955fc8acbf8393c5f2857ddf5e5fe3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 17 Jun 2024 16:40:44 +0200 Subject: [PATCH 057/340] Set maximum grpc message receive size to 2GiB (#2075) * Set maximum grpc message receive size to 2GiB The previous default was 4MiB, which doesn't really work well for multi-modal models. * Update to Rust 1.79.0 * Fixup formatting to make PR pass --- CODE_OF_CONDUCT.md | 2 +- CONTRIBUTING.md | 22 +++++++++++----------- Dockerfile_amd | 2 +- Dockerfile_intel | 2 +- benchmark/src/app.rs | 12 ++++++------ benchmark/src/table.rs | 6 +++--- benchmark/src/utils.rs | 2 +- rust-toolchain.toml | 6 +++--- server/text_generation_server/server.py | 6 +++++- 9 files changed, 32 insertions(+), 28 deletions(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index ef09fa13..b23f3150 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -130,4 +130,4 @@ For answers to common questions about this code of conduct, see the FAQ at [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html [Mozilla CoC]: https://github.com/mozilla/diversity [FAQ]: https://www.contributor-covenant.org/faq -[translations]: https://www.contributor-covenant.org/translations \ No newline at end of file +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 39b57c19..d541e47f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,10 +55,10 @@ feedback. The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter. Before you report an issue, we would really appreciate it if you could **make sure the bug was not -already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the -library itself, and not your code. +already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the +library itself, and not your code. -Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so +Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it: * Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies). @@ -79,20 +79,20 @@ that in your issue report. If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe: -1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it - a feature related to something you need for a project? Is it something you worked on and think it could benefit +1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it + a feature related to something you need for a project? Is it something you worked on and think it could benefit the community? Whatever it is, we'd love to hear about it! -2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better +2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you. 3. Provide a *code snippet* that demonstrates the feature's usage. 4. If the feature is related to a paper, please include a link. If your issue is well written we're already 80% of the way there by the time you create it. -We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE) +We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE) to help you get started with your issue. ## Do you want to implement a new model? @@ -107,14 +107,14 @@ If you are willing to contribute the model yourself, let us know so we can help ## Do you want to add documentation? -We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know -how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be +We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know +how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be happy to make the changes or help you make a contribution if you're interested! ## I want to become a maintainer of the project. How do I get there? TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference -service. +service. -If you are such an individual (or organization), please reach out to us and let's collaborate. \ No newline at end of file +If you are such an individual (or organization), please reach out to us and let's collaborate. diff --git a/Dockerfile_amd b/Dockerfile_amd index c79bc03c..55da9204 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/Dockerfile_intel b/Dockerfile_intel index cb0e84bb..35362fc9 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,4 +1,4 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs index 48ac976a..a0a9313a 100644 --- a/benchmark/src/app.rs +++ b/benchmark/src/app.rs @@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Lowest: {:.2} {unit}", data.iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Highest: {:.2} {unit}", data.iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>( let min_latency: f64 = *latency_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_latency: f64 = *latency_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let min_throughput: f64 = *throughput_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_throughput: f64 = *throughput_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); // Char min max values let min_x = if zoom { diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index e18d7310..1585a25f 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) { let min = data .iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max = data .iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); (average, *min, *max) } fn px(data: &[f64], p: u32) -> f64 { let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; - *data.get(i).unwrap_or(&std::f64::NAN) + *data.get(i).unwrap_or(&f64::NAN) } fn format_value(value: f64, unit: &'static str) -> String { diff --git a/benchmark/src/utils.rs b/benchmark/src/utils.rs index d096d655..20469991 100644 --- a/benchmark/src/utils.rs +++ b/benchmark/src/utils.rs @@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap Date: Thu, 20 Jun 2024 07:56:16 +0200 Subject: [PATCH 058/340] Support exl2-quantized Qwen2 models (#2085) Fixes #2081. --- .../custom_modeling/flash_qwen2_modeling.py | 27 +++---------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index b1de58b2..df5a8ae9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -40,31 +40,12 @@ def _load_gqa(config, prefix: str, weights): assert config.hidden_size % config.num_attention_heads == 0 assert config.num_attention_heads % weights.process_group.size() == 0 - weight = weights.get_multi_weights_col( + return TensorParallelColumnLinear.load_multi( + config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, - ) - - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - head_size = config.hidden_size // config.num_attention_heads - num_heads = config.num_attention_heads // weights.process_group.size() - num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ - (num_heads + 2 * num_key_value_heads) * head_size, - config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - - w = [ - weights.get_sharded(f"{p}.bias", dim=0) - for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] - ] - bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) - - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) + weights=weights, + bias=True, ) From c61ef1ce85c4faec7d8aab2713e634e26e735087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 20 Jun 2024 09:56:04 +0200 Subject: [PATCH 059/340] Factor out sharding of packed tensors (#2059) For Phi-3-Small I need to shard a packed QKV bias tensor, for which I implemented the `Weights.get_packed_sharded` method. However, this method can also replace the `Weights._get_qweight` method and the custom sharding code from `Weights.get_weights_col_packed`. --- .../text_generation_server/utils/weights.py | 99 +++++++++++-------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 45cfc073..e6142525 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -130,29 +130,57 @@ class Weights: ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]): - slice_ = self._get_slice(name) - total_size = slice_.get_shape()[1] + def get_packed_sharded( + self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + ) -> torch.Tensor: + """ + Get a shard from a tensor that packs multiple tensors. + + When a tensor packs multiple tensors (such as QKV or an up + projection + gate projection), sharding with `get_sharded` is not + safe since it would not split the packed tensors across shards. + + This method shards a tensor, such that the packed tensors are + split across shards. + + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. + """ + slice_ = self._get_slice(tensor_name) + total_size = slice_.get_shape()[dim] block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes) world_size = self.process_group.size() rank = self.process_group.rank() - weights = [] + tensors = [] block_offset = 0 for block_size in block_sizes: assert ( block_size % world_size == 0 - ), f"Prepacked qkv cannot be sharded across {world_size} shards" + ), f"Prepacked tensor cannot be sharded across {world_size} shards" shard_block_size = block_size // world_size start = rank * shard_block_size stop = (rank + 1) * shard_block_size - weights.append(slice_[:, block_offset + start : block_offset + stop]) + if dim == 0: + tensor = slice_[block_offset + start : block_offset + stop] + elif dim == 1: + tensor = slice_[:, block_offset + start : block_offset + stop] + else: + raise NotImplementedError("Currently only dim=0 or dim=1 is supported") + tensors.append(tensor) block_offset += block_size + tensor = torch.cat(tensors, dim=dim) + tensor = tensor.to(device=self.device) - weight = torch.cat(weights, dim=1) - weight = weight.to(device=self.device) - return weight + # Avoid casting quantizer dtypes. + if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + tensor = tensor.to(dtype=self.dtype) + + return tensor def get_weights_col_packed_qkv( self, @@ -185,7 +213,9 @@ class Weights: from text_generation_server.layers.gptq import GPTQWeight try: - qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + qweight = self.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." @@ -193,8 +223,12 @@ class Weights: gptq_params = self._get_gptq_params() - qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes) - scales = self._get_qweight(f"{prefix}.scales", block_sizes) + qzeros = self.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) scales = scales.to(dtype=self.dtype) if quantize == "gptq" and gptq_params.quant_method == "gptq": @@ -237,13 +271,17 @@ class Weights: if quant_method == "gptq": gptq_params = self._get_gptq_params() try: - qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + qweight = self.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - scales = self._get_qweight(f"{prefix}.scales", block_sizes) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) g_idx = self.get_tensor(f"{prefix}.g_idx") weight = repack_gptq_for_marlin( qweight=qweight, @@ -257,34 +295,17 @@ class Weights: ) else: - B = self._get_qweight(f"{prefix}.B", block_sizes) - s = self._get_qweight(f"{prefix}.s", block_sizes) + B = self.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = self.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) weight = MarlinWeight(B=B, s=s) else: - slice_ = self._get_slice(f"{prefix}.weight") - total_size = slice_.get_shape()[0] - block_sizes = _blocks_to_block_sizes( - total_size=total_size, blocks=block_sizes + weight = self.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes ) - - world_size = self.process_group.size() - rank = self.process_group.rank() - - tensors = [] - block_offset = 0 - for block_size in block_sizes: - assert ( - block_size % world_size == 0 - ), f"Prepacked weights cannot be sharded across {world_size} shards" - shard_block_size = block_size // world_size - start = rank * shard_block_size - stop = (rank + 1) * shard_block_size - tensor = slice_[block_offset + start : block_offset + stop] - tensors.append(tensor) - block_offset += block_size - weight = torch.cat(tensors, dim=0) - weight = weight.to(device=self.device) - weight = weight.to(dtype=self.dtype) return weight def get_weights_col(self, prefix: str, quantize: str): From f0ed8d294f56ef4b3d136f74b15eb0cf008d55bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 21 Jun 2024 15:28:51 +0200 Subject: [PATCH 060/340] Fix `text-generation-server quantize` (#2103) The subcommand did not work due to some broken imports. --- server/text_generation_server/cli.py | 2 +- server/text_generation_server/layers/gptq/quantize.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 430323bc..5d25bfc5 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -316,7 +316,7 @@ def quantize( logger_level=logger_level, json_output=json_output, ) - from text_generation_server.utils.gptq.quantize import quantize + from text_generation_server.layers.gptq.quantize import quantize quantize( model_id=model_id, diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index ca113d8f..8d029817 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -12,7 +12,7 @@ from huggingface_hub import HfApi from accelerate import init_empty_weights from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils.hub import weight_files -from text_generation_server.utils.gptq.quant_linear import QuantLinear +from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional From d930724e82cd5368796e95456a2893ad174a062d Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 21 Jun 2024 14:28:26 -0400 Subject: [PATCH 061/340] feat: sort cuda graphs in descending order (#2104) --- server/text_generation_server/models/globals.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 11a9f030..970a673b 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -15,6 +15,13 @@ if cuda_graphs is not None: else: cuda_graphs = None + +# sorting the cuda graphs in descending order helps reduce the +# memory impact and results in less memory usage +if cuda_graphs is not None: + cuda_graphs.sort(reverse=True) + + CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. From b6a59e2f91f74dd32d4a90613ff922cd4888e318 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 24 Jun 2024 18:08:34 +0200 Subject: [PATCH 062/340] New runner. Manual squash. (#2110) * New runner. Manual squash. * Network host. * Put back trufflehog with proper extension. * No network host ? * Moving buildx install after tailscale ? * 1.79 --- .github/workflows/ci_build.yaml | 36 ++++++++++++++++ .github/workflows/integration_tests.yaml | 41 +++++++++++++++++++ .../{trufflehog.yml => trufflehog.yaml} | 0 integration-tests/conftest.py | 18 ++++++-- 4 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/ci_build.yaml create mode 100644 .github/workflows/integration_tests.yaml rename .github/workflows/{trufflehog.yml => trufflehog.yaml} (100%) diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml new file mode 100644 index 00000000..754c4850 --- /dev/null +++ b/.github/workflows/ci_build.yaml @@ -0,0 +1,36 @@ +name: CI build + +on: + push: + branches: + - 'main' + tags: + - 'v*' + pull_request: + paths: + - ".github/workflows/build.yaml" + - "integration-tests/**" + - "server/**" + - "proto/**" + - "router/**" + - "launcher/**" + - "Cargo.lock" + - "rust-toolchain.toml" + - "Dockerfile" + - "Dockerfile_amd" + - "Dockerfile_intel" + branches: + - 'main' + +jobs: + build: + strategy: + # super important if you want to see all results, even if one fails + # fail-fast is true by default + fail-fast: false + matrix: + hardware: ["cuda", "rocm", "intel"] + uses: ./.github/workflows/build.yaml # calls the one above ^ + with: + hardware: ${{ matrix.hardware }} + secrets: inherit diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml new file mode 100644 index 00000000..4e111afe --- /dev/null +++ b/.github/workflows/integration_tests.yaml @@ -0,0 +1,41 @@ +name: Integration tests + +on: + workflow_call: + inputs: + docker_image: + type: string + description: Hardware + required: true + docker_devices: + type: string + description: Hardware + runs_on: + type: string + required: true + description: Hardware to run integration tests +jobs: + integration_tests: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + runs-on: ${{ inputs.runs_on }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.9 + - name: Install + run: | + make install-integration-tests + - name: Run tests + run: | + export DOCKER_VOLUME=/mnt/cache + export DOCKER_IMAGE=${{ inputs.docker_image }} + export DOCKER_DEVICES=${{ inputs.docker_devices }} + export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} + pytest -s -vv integration-tests diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yaml similarity index 100% rename from .github/workflows/trufflehog.yml rename to .github/workflows/trufflehog.yaml diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 2ef85da6..0b239484 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -34,6 +34,7 @@ from text_generation.types import ( DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") +DOCKER_DEVICES = os.getenv("DOCKER_DEVICES") class ResponseComparator(JSONSnapshotExtension): @@ -453,6 +454,18 @@ def launcher(event_loop): if DOCKER_VOLUME: volumes = [f"{DOCKER_VOLUME}:/data"] + if DOCKER_DEVICES: + devices = DOCKER_DEVICES.split(",") + visible = os.getenv("ROCR_VISIBLE_DEVICES") + if visible: + env["ROCR_VISIBLE_DEVICES"] = visible + device_requests = [] + else: + devices = [] + device_requests = [ + docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) + ] + container = client.containers.run( DOCKER_IMAGE, command=args, @@ -460,9 +473,8 @@ def launcher(event_loop): environment=env, auto_remove=False, detach=True, - device_requests=[ - docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) - ], + device_requests=device_requests, + devices=devices, volumes=volumes, ports={"80/tcp": port}, shm_size="1G", From 4b25048b75e0627f2418e171e6196fae2de5f84f Mon Sep 17 00:00:00 2001 From: ur4t <46435411+ur4t@users.noreply.github.com> Date: Tue, 25 Jun 2024 00:16:36 +0800 Subject: [PATCH 063/340] Fix cargo-chef prepare (#2101) * Fix cargo-chef prepare In prepare stage, cargo-chef reads Cargo.lock and transforms it accordingly. If Cargo.lock is not present, cargo-chef will generate a new one first, which might vary a lot and invalidate docker build caches. * Fix Dockerfile_amd and Dockerfile_intel --- Dockerfile | 1 + Dockerfile_amd | 1 + Dockerfile_intel | 1 + 3 files changed, 3 insertions(+) diff --git a/Dockerfile b/Dockerfile index 6492d723..8bc87a0e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef WORKDIR /usr/src FROM chef as planner +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto diff --git a/Dockerfile_amd b/Dockerfile_amd index 55da9204..e325bb65 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -5,6 +5,7 @@ WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse FROM chef as planner +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto diff --git a/Dockerfile_intel b/Dockerfile_intel index 35362fc9..131f49ba 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -4,6 +4,7 @@ WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse FROM chef as planner +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto From 931ff16c7a6c37fda910f88cd0099826a593607e Mon Sep 17 00:00:00 2001 From: Lucain Date: Tue, 25 Jun 2024 09:23:12 +0200 Subject: [PATCH 064/340] Support `HF_TOKEN` environment variable (#2066) * Support HF_TOKEN environement variable * Load test. --------- Co-authored-by: Nicolas Patry --- benchmark/src/main.rs | 2 +- .../basic_tutorials/gated_model_access.md | 8 ++-- integration-tests/conftest.py | 38 +++++++++---------- launcher/src/main.rs | 6 +-- router/src/main.rs | 2 +- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index b9d80b7a..603b4087 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -147,7 +147,7 @@ fn main() -> Result<(), Box> { tracing::info!("Downloading tokenizer"); // Parse Huggingface hub token - let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok(); // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index b49c59c9..ef3a1db7 100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -2,13 +2,13 @@ If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens) -If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example: +If you're using the CLI, set the `HF_TOKEN` environment variable. For example: ``` -export HUGGING_FACE_HUB_TOKEN= +export HF_TOKEN= ``` -If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below. +If you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below. ```bash model=meta-llama/Llama-2-7b-chat-hf @@ -17,7 +17,7 @@ token= docker run --gpus all \ --shm-size 1g \ - -e HUGGING_FACE_HUB_TOKEN=$token \ + -e HF_TOKEN=$token \ -p 8080:80 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \ --model-id $model diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 0b239484..13337165 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -1,38 +1,38 @@ -import sys -import subprocess -import contextlib -import pytest import asyncio -import os -import docker +import contextlib import json import math +import os +import random +import re import shutil +import subprocess +import sys import tempfile import time -import random +from typing import Dict, List, Optional -from docker.errors import NotFound -from typing import Optional, List, Dict -from syrupy.extensions.json import JSONSnapshotExtension +import docker +import pytest from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError - +from docker.errors import NotFound +from syrupy.extensions.json import JSONSnapshotExtension from text_generation import AsyncClient from text_generation.types import ( - Response, - Details, - InputToken, - Token, BestOfSequence, - Grammar, ChatComplete, ChatCompletionChunk, ChatCompletionComplete, Completion, + Details, + Grammar, + InputToken, + Response, + Token, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) -HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) +HF_TOKEN = os.getenv("HF_TOKEN", None) DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") DOCKER_DEVICES = os.getenv("DOCKER_DEVICES") @@ -447,8 +447,8 @@ def launcher(event_loop): if not use_flash_attention: env["USE_FLASH_ATTENTION"] = "false" - if HUGGING_FACE_HUB_TOKEN is not None: - env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN + if HF_TOKEN is not None: + env["HF_TOKEN"] = HF_TOKEN volumes = [] if DOCKER_VOLUME: diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 21ca63d4..51dabcf1 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -595,7 +595,7 @@ fn shard_manager( // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HF_TOKEN".into(), api_token.into())) }; // Detect rope scaling @@ -929,7 +929,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HF_TOKEN".into(), api_token.into())) }; // If args.weights_cache_override is some, pass it to the download process @@ -1231,7 +1231,7 @@ fn spawn_webserver( // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HF_TOKEN".into(), api_token.into())) }; // Parse Compute type diff --git a/router/src/main.rs b/router/src/main.rs index c4203dbc..013176f3 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -156,7 +156,7 @@ async fn main() -> Result<(), RouterError> { }); // Parse Huggingface hub token - let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok(); // Tokenizer instance // This will only be used to validate payloads From 76c6a5ca2a5cd46c82b465af864ec13db4a1e419 Mon Sep 17 00:00:00 2001 From: KevinDuffy94 Date: Tue, 25 Jun 2024 08:33:01 +0100 Subject: [PATCH 065/340] Add OTLP Service Name Environment Variable (#2076) * Adding Service Name Environment variable for https://github.com/huggingface/text-generation-inference/issues/2069 * Update Docs * Update README.md * Update Launcher Docs * Update Launcher Docs Removing Option --- docs/source/architecture.md | 4 ++++ docs/source/basic_tutorials/launcher.md | 7 +++++++ launcher/src/main.rs | 18 +++++++++++++++++- router/src/main.rs | 10 +++++++--- server/text_generation_server/cli.py | 3 ++- server/text_generation_server/tracing.py | 6 ++---- 6 files changed, 39 insertions(+), 9 deletions(-) diff --git a/docs/source/architecture.md b/docs/source/architecture.md index b7885879..a8418817 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -70,6 +70,8 @@ Options: [env: JSON_OUTPUT=] --otlp-endpoint [env: OTLP_ENDPOINT=] + --otlp-service-name + [env: OTLP_SERVICE_NAME=] --cors-allow-origin [env: CORS_ALLOW_ORIGIN=] --ngrok @@ -138,6 +140,8 @@ Serve's command line parameters on the TGI repository are these: │ --logger-level TEXT [default: INFO] │ │ --json-output --no-json-output [default: no-json-output] │ │ --otlp-endpoint TEXT [default: None] │ +│ --otlp-service-name TEXT [default: │ +│ text-generation-inference...│ │ --help Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 9246093e..f6175925 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -336,6 +336,13 @@ Options: --otlp-endpoint [env: OTLP_ENDPOINT=] +``` +## OTLP_SERVICE_NAME +```shell + --otlp-service-name + [env: OTLP_SERVICE_NAME=] + [default: text-generation-inference.router] + ``` ## CORS_ALLOW_ORIGIN ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 51dabcf1..a72155c7 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -415,6 +415,9 @@ struct Args { #[clap(long, env)] otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] cors_allow_origin: Vec, #[clap(long, env)] @@ -485,6 +488,7 @@ fn shard_manager( max_batch_size: Option, max_input_tokens: usize, otlp_endpoint: Option, + otlp_service_name: String, log_level: LevelFilter, status_sender: mpsc::Sender, shutdown: Arc, @@ -549,12 +553,16 @@ fn shard_manager( (None, Some(factor)) => Some((RopeScaling::Linear, factor)), }; - // OpenTelemetry + // OpenTelemetry Endpoint if let Some(otlp_endpoint) = otlp_endpoint { shard_args.push("--otlp-endpoint".to_string()); shard_args.push(otlp_endpoint); } + // OpenTelemetry Service Name + shard_args.push("--otlp-service-name".to_string()); + shard_args.push(otlp_service_name); + // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter. shard_args.push("--max-input-tokens".to_string()); shard_args.push(max_input_tokens.to_string()); @@ -1039,6 +1047,7 @@ fn spawn_shards( let shutdown = shutdown.clone(); let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); + let otlp_service_name = args.otlp_service_name.clone(); let quantize = args.quantize; let speculate = args.speculate; let dtype = args.dtype; @@ -1078,6 +1087,7 @@ fn spawn_shards( max_batch_size, max_input_tokens, otlp_endpoint, + otlp_service_name, max_log_level, status_sender, shutdown, @@ -1211,6 +1221,12 @@ fn spawn_webserver( router_args.push(otlp_endpoint); } + // OpenTelemetry + let otlp_service_name = args.otlp_service_name; + router_args.push("--otlp-service-name".to_string()); + router_args.push(otlp_service_name); + + // CORS origins for origin in args.cors_allow_origin.into_iter() { router_args.push("--cors-allow-origin".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index 013176f3..dcb9ce99 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -65,6 +65,8 @@ struct Args { json_output: bool, #[clap(long, env)] otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] @@ -107,6 +109,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, json_output, otlp_endpoint, + otlp_service_name, cors_allow_origin, ngrok, ngrok_authtoken, @@ -117,7 +120,7 @@ async fn main() -> Result<(), RouterError> { } = args; // Launch Tokio runtime - init_logging(otlp_endpoint, json_output); + init_logging(otlp_endpoint, otlp_service_name, json_output); // Validate args if max_input_tokens >= max_total_tokens { @@ -367,10 +370,11 @@ async fn main() -> Result<(), RouterError> { /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - otlp_service_name service name to appear in APM /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) /// - LOG_FORMAT may be TEXT or JSON (default to TEXT) /// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) -fn init_logging(otlp_endpoint: Option, json_output: bool) { +fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { let mut layers = Vec::new(); // STDOUT/STDERR layer @@ -401,7 +405,7 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { trace::config() .with_resource(Resource::new(vec![KeyValue::new( "service.name", - "text-generation-inference.router", + otlp_service_name, )])) .with_sampler(Sampler::AlwaysOn), ) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 5d25bfc5..18cad071 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -42,6 +42,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + otlp_service_name: str = "text-generation-inference.server", max_input_tokens: Optional[int] = None, ): if sharded: @@ -76,7 +77,7 @@ def serve( # Setup OpenTelemetry distributed tracing if otlp_endpoint is not None: - setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) + setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value diff --git a/server/text_generation_server/tracing.py b/server/text_generation_server/tracing.py index bf03c379..bc7a04ee 100644 --- a/server/text_generation_server/tracing.py +++ b/server/text_generation_server/tracing.py @@ -54,10 +54,8 @@ class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor): ) -def setup_tracing(shard: int, otlp_endpoint: str): - resource = Resource.create( - attributes={"service.name": f"text-generation-inference.server-{shard}"} - ) +def setup_tracing(otlp_service_name: str, otlp_endpoint: str): + resource = Resource.create(attributes={"service.name": otlp_service_name}) span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) span_processor = BatchSpanProcessor(span_exporter) From 1952a0b03b391929b146e25488018c37e18cb60e Mon Sep 17 00:00:00 2001 From: Jeff Date: Tue, 25 Jun 2024 04:10:32 -0400 Subject: [PATCH 066/340] corrected Pydantic warning. (#2095) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * corrected Pydantic warning. * Update clients/python/text_generation/types.py Co-authored-by: Daniël de Kok --------- Co-authored-by: Nicolas Patry Co-authored-by: Daniël de Kok --- clients/python/text_generation/types.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index eb872ee6..497468d9 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,5 +1,5 @@ from enum import Enum -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, ConfigDict from typing import Optional, List, Union, Any from text_generation.errors import ValidationError @@ -452,5 +452,9 @@ class StreamResponse(BaseModel): # Inference API currently deployed model class DeployedModel(BaseModel): + # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members + # with model_ prefixes, since this disables guardrails for colliding fields: + # https://github.com/pydantic/pydantic/issues/9177 + model_config = ConfigDict(protected_namespaces=()) model_id: str sha: str From e49aed47138e8d49288e8107b56a55be2c0ef667 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 25 Jun 2024 16:15:46 +0800 Subject: [PATCH 067/340] use xpu-smi to dump used memory (#2047) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * use xpu-smi to dump used memory xpu use "ZE_AFFINITY_MASK" to control card, usage is like CUDA_VISIBLE_DEVICES Signed-off-by: Wang, Yi A * Update server/text_generation_server/utils/import_utils.py Co-authored-by: Daniël de Kok --------- Signed-off-by: Wang, Yi A Co-authored-by: Daniël de Kok --- Dockerfile_intel | 2 +- launcher/src/main.rs | 9 ++++++--- server/text_generation_server/utils/import_utils.py | 9 +++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index 131f49ba..f09614d4 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -49,7 +49,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list -RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build +RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils # Text Generation Inference base env ENV HUGGINGFACE_HUB_CACHE=/data \ diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a72155c7..90b587cc 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -763,7 +763,10 @@ fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver fn num_cuda_devices() -> Option { let devices = match env::var("CUDA_VISIBLE_DEVICES") { Ok(devices) => devices, - Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("ZE_AFFINITY_MASK").ok()?, + } }; let n_devices = devices.split(',').count(); Some(n_devices) @@ -836,9 +839,9 @@ fn find_num_shards( let num_shard = match (sharded, num_shard) { (Some(true), None) => { // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK"); let n_devices = num_cuda_devices() - .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); + .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set"); if n_devices <= 1 { return Err(LauncherError::NotEnoughCUDADevices(format!( "`sharded` is true but only found {n_devices} CUDA devices" diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index d79e36c2..c3929392 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,5 +1,6 @@ import torch from loguru import logger +import subprocess def is_xpu_available(): @@ -19,8 +20,12 @@ def get_cuda_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction): - total_gpu_memory = torch.xpu.get_device_properties(device).total_memory - free_memory = int(total_gpu_memory * 0.5) + total_memory = torch.xpu.get_device_properties(device).total_memory + device_id = device.index + query = f"xpu-smi dump -d {device_id} -m 18 -n 1" + output = subprocess.check_output(query.split()).decode("utf-8").split("\n") + used_memory = float(output[1].split(",")[-1]) * 1024 * 1024 + free_memory = int(total_memory * 0.95 - used_memory) return free_memory From a9faabc3742440aa6b196666b6d5eea288edc25a Mon Sep 17 00:00:00 2001 From: sunxichen Date: Tue, 25 Jun 2024 16:59:50 +0800 Subject: [PATCH 068/340] fix ChatCompletion and ChatCompletionChunk object string not compatible with standard openai api (#2089) Co-authored-by: sunxichen --- router/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index b0b93c13..5d201937 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -570,7 +570,7 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "text_completion".into(), + object: "chat.completion".into(), created, model, system_fingerprint, @@ -682,7 +682,7 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "text_completion".to_string(), + object: "chat.completion.chunk".to_string(), created, model, system_fingerprint, From 0d879fe66e9c8837c54ee717a82553ea6f7c5b34 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 25 Jun 2024 18:21:29 +0800 Subject: [PATCH 069/340] Cpu tgi (#1936) * add CPU tgi support Signed-off-by: Wang, Yi A * ipex distributed ops support Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A Co-authored-by: Funtowicz Morgan --- Dockerfile_intel | 87 ++++++++++++++++++- .../layers/attention/__init__.py | 4 +- .../layers/attention/xpu.py | 5 +- .../layers/layernorm.py | 24 ++--- .../text_generation_server/layers/rotary.py | 6 +- .../layers/tensor_parallel.py | 31 +++++-- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../models/flash_causal_lm.py | 49 +++++++---- .../models/flash_gpt2.py | 5 +- .../models/flash_llama.py | 5 +- .../models/flash_mistral.py | 5 +- .../models/flash_neox.py | 5 +- .../text_generation_server/models/flash_rw.py | 5 +- .../models/flash_santacoder.py | 5 +- server/text_generation_server/utils/dist.py | 35 ++++---- .../utils/import_utils.py | 19 ++-- 17 files changed, 221 insertions(+), 77 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index f09614d4..a41fbc1e 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,3 +1,5 @@ +ARG PLATFORM=xpu + FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src @@ -37,7 +39,8 @@ RUN cargo build --profile release-opt # Text Generation Inference base image for Intel -FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base + +FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu USER root # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it @@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl -RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed # Install server COPY proto proto @@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca # Install launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher -# Final image -FROM base +# Text Generation Inference base image for Intel-cpu +FROM ubuntu:22.04 as cpu + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + make \ + g++ \ + git \ + wget \ + cmake + +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN conda install -c conda-forge gperftools mkl + +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl + +WORKDIR /usr/src + +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a + +RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131 + +RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install + +RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . + +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so +ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch +ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch +ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric +ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib +ENV KMP_BLOCKTIME=1 +ENV KMP_TPAUSE=0 +ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist +ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist +ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_intel.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher + +FROM ${PLATFORM} as final ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index e6cb4edf..a53c8e3b 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,4 +1,4 @@ -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL import os if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": @@ -7,7 +7,7 @@ if SYSTEM == "cuda": from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING -elif SYSTEM == "xpu": +elif IPEX_AVAIL: from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py index 8b6cb87b..bfab0119 100644 --- a/server/text_generation_server/layers/attention/xpu.py +++ b/server/text_generation_server/layers/attention/xpu.py @@ -1,5 +1,6 @@ import intel_extension_for_pytorch as ipex import torch +from text_generation_server.models.flash_causal_lm import BLOCK_SIZE SUPPORTS_WINDOWING = False @@ -56,8 +57,6 @@ def paged_attention( input_lengths: torch.Tensor, max_s: int, ): - query = query.contiguous() - block_size = value_cache.shape[3] return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, @@ -67,7 +66,7 @@ def paged_attention( softmax_scale, block_tables, input_lengths, - block_size, + BLOCK_SIZE, max_s, None, ) diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py index c4aa6c7d..93e83dfa 100644 --- a/server/text_generation_server/layers/layernorm.py +++ b/server/text_generation_server/layers/layernorm.py @@ -3,6 +3,7 @@ from torch import nn from accelerate import init_empty_weights from text_generation_server.utils.import_utils import ( SYSTEM, + IPEX_AVAIL, ) @@ -82,18 +83,20 @@ elif SYSTEM == "rocm": return super().forward(hidden_states), residual -elif SYSTEM == "xpu": +elif IPEX_AVAIL: import intel_extension_for_pytorch as ipex class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - res_out = hidden_states out = ipex.llm.functional.add_layer_norm( - residual, hidden_states, self.weight, self.bias, self.eps, True + residual, + hidden_states, + self.weight, + self.bias, + self.eps, + residual is not None, ) - if residual is not None: - res_out = residual - return out, res_out + return out, residual if residual is not None else hidden_states class FastRMSNorm(nn.Module): @@ -109,19 +112,16 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if SYSTEM == "xpu": - residual_out = hidden_states + if IPEX_AVAIL: out = ipex.llm.functional.add_rms_norm( residual, hidden_states, self.weight, None, self.variance_epsilon, - True, + residual is not None, ) - if residual is not None: - residual_out = residual - return out, residual_out + return out, residual if residual is not None else hidden_states elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index c2f12189..1892cf69 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -2,14 +2,14 @@ import os import torch from torch import nn -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL if SYSTEM == "cuda": from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops -elif SYSTEM == "xpu": +elif IPEX_AVAIL: import intel_extension_for_pytorch as ipex @@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "xpu": + elif IPEX_AVAIL: ipex.llm.functional.rotary_embedding( query, key, sin, cos, query.size(-1), True ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 6005f737..510dc2c6 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -3,6 +3,10 @@ from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.utils.import_utils import IPEX_AVAIL + +if IPEX_AVAIL: + import intel_extension_for_pytorch as ipex class LayerConcat(torch.nn.Module): @@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + if IPEX_AVAIL: + ipex.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + else: + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - torch.distributed.all_gather(world_output, output, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_gather(world_output, output, group=self.process_group) + else: + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - torch.distributed.all_reduce(out, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - torch.distributed.all_reduce(out, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 94cf6452..56292250 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,9 +20,9 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import IPEX_AVAIL -if SYSTEM != "xpu": +if not IPEX_AVAIL: from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 3900bf73..6ea95411 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,9 +24,9 @@ import torch.distributed import numpy as np from torch import nn -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import IPEX_AVAIL -if SYSTEM != "xpu": +if not IPEX_AVAIL: from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d16d3710..633b066b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK @@ -773,21 +773,38 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] + if IPEX_AVAIL and SYSTEM == "cpu": + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index ef129e92..43f374e5 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -37,6 +37,9 @@ class FlashGPT2(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGPT2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index e27f0da2..e023c4e0 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -17,7 +17,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL class FlashLlama(FlashCausalLM): @@ -37,6 +37,9 @@ class FlashLlama(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0fdda6d2..0c570487 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -16,7 +16,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -41,6 +41,9 @@ class BaseFlashMistral(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index d3871c2f..69d47e57 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -36,6 +36,9 @@ class FlashNeoXSharded(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 187f26a8..e087dcf1 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -37,6 +37,9 @@ class FlashRWSharded(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a8d84fca..9626af60 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -40,6 +40,9 @@ class FlashSantacoderSharded(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 3625e6f2..7d387563 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,6 +3,7 @@ import torch from datetime import timedelta from loguru import logger +from text_generation_server.utils.import_utils import IPEX_AVAIL # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) @@ -57,14 +58,7 @@ def initialize_torch_distributed(): options.is_high_priority_stream = True options._timeout = timedelta(seconds=60) else: - try: - import oneccl_bindings_for_pytorch - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" + backend = "gloo" options = None if WORLD_SIZE == 1: @@ -75,13 +69,24 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. - torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if IPEX_AVAIL: + import intel_extension_for_pytorch as ipex + + ipex.distributed.init_process_group( + backend="ccl", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) else: logger.warning("torch.distributed is already initialized.") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index c3929392..a244417a 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -3,13 +3,12 @@ from loguru import logger import subprocess -def is_xpu_available(): +def is_ipex_available(): try: import intel_extension_for_pytorch except ImportError: return False - - return hasattr(torch, "xpu") and torch.xpu.is_available() + return True def get_cuda_free_memory(device, memory_fraction): @@ -29,6 +28,16 @@ def get_xpu_free_memory(device, memory_fraction): return free_memory +def get_cpu_free_memory(device, memory_fraction): + import psutil + from text_generation_server.utils.dist import WORLD_SIZE + + mem = psutil.virtual_memory() + free_memory = int(mem.available * 0.95 / WORLD_SIZE) + return free_memory + + +IPEX_AVAIL = is_ipex_available() SYSTEM = None if torch.version.hip is not None: SYSTEM = "rocm" @@ -40,7 +49,7 @@ elif torch.version.cuda is not None and torch.cuda.is_available(): empty_cache = torch.cuda.empty_cache synchronize = torch.cuda.synchronize get_free_memory = get_cuda_free_memory -elif is_xpu_available(): +elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available(): SYSTEM = "xpu" empty_cache = torch.xpu.empty_cache synchronize = torch.xpu.synchronize @@ -53,5 +62,5 @@ else: empty_cache = noop synchronize = noop - get_free_memory = noop + get_free_memory = get_cpu_free_memory logger.info(f"Detected system {SYSTEM}") From 1f70bb75e35fee9a004680f4b379a65b2783f558 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 25 Jun 2024 06:22:59 -0400 Subject: [PATCH 070/340] feat: add simple tests for weights (#2092) * feat: add simple tests for weights * fix: adjust types and add tests * fix: adjust so all tests pass * feat: improve weight tests * fix: add missing tests and renames * fix: tweak shapes --- server/tests/utils/test_weights.py | 1152 ++++++++++++++++++++++++++++ 1 file changed, 1152 insertions(+) create mode 100644 server/tests/utils/test_weights.py diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py new file mode 100644 index 00000000..8f88b1f8 --- /dev/null +++ b/server/tests/utils/test_weights.py @@ -0,0 +1,1152 @@ +import pytest +import torch +from text_generation_server.utils.weights import Weights +from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.marlin import MarlinWeight +from types import SimpleNamespace +from typing import List, Optional, Dict, Union +from pathlib import Path + +dummy_file_system = { + "test_weights": { + "layer.0.weight": torch.tensor( + [ + [1, 2], + [3, 4], + ], + dtype=torch.float32, + ), + }, + "test_weights_2": { + "layer.1337.weight": torch.tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_col_packed": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_multi_weights_col": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_multi_weights_row": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_col_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + }, + "test_get_multi_weights_row_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_multi_weights_col_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_packed_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_packed_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_row_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_col_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_weights_col_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_row_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, + "test_get_multi_weights_col_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, + "test_get_weights_col_packed_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, +} + + +class MockSlice: + def __init__(self, tensor): + self.tensor = tensor + + def get_shape(self): + return self.tensor.shape + + def __getitem__(self, idx): + return self.tensor[idx] + + +def mock_get_slice(tensor_name, filename): + tensor = dummy_file_system[filename][tensor_name] + return MockSlice(tensor) + + +def mock_handle(filename, device, dtype): + return SimpleNamespace( + get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename) + ) + + +class MockSafeOpen: + def __init__(self, filename, framework, dummy_fs): + self.filename = filename + self.framework = framework + self.dummy_fs = dummy_fs + + def keys(self): + return list(self.dummy_fs[self.filename].keys()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class MockWeights(Weights): + def __init__( + self, + filenames: List[Union[Path, str]], + device, + dtype, + process_group, + dummy_fs, + aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None, + ): + routing = {} + self.dummy_fs = dummy_fs + for filename in filenames: + with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self.prefix = prefix + self._handles = {} + + def _get_handle(self, filename: Union[Path, str]): + if filename in self._handles: + return self._handles[filename] + else: + handle = mock_handle(filename, self.device, self.dtype) + self._handles[filename] = handle + return handle + + def get_shape(self, tensor_name: str): + filename, _ = self.get_filename(tensor_name) + handle = self._get_handle(filename) + return handle.get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str): + filename, _ = self.get_filename(tensor_name) + handle = self._get_handle(filename) + return handle.get_slice(tensor_name).tensor + + +dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1) + + +def test_weights(): + weights = MockWeights( + [ + "test_weights", + "test_weights_2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + assert weights.get_shape("layer.0.weight") == (2, 2) + assert weights.get_tensor("layer.1337.weight").shape == (2, 4) + + +def test_get_tensor(): + weights = MockWeights( + [ + "test_weights", + "test_weights_2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + assert torch.allclose( + weights.get_tensor("layer.0.weight"), + torch.tensor( + [ + [1, 2], + [3, 4], + ], + dtype=torch.float32, + ), + ) + assert torch.allclose( + weights.get_tensor("layer.1337.weight"), + torch.tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed_block_size(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + block_sizes = 2 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed_block_size_arr(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + block_sizes = [1, 1] + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_multi_weights_col(): + weights = MockWeights( + [ + "test_get_multi_weights_col", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight", "weight"] + quantize = None + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_multi_weights_row(): + weights = MockWeights( + [ + "test_get_multi_weights_row", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + assert torch.allclose( + w, + torch.tensor( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + dtype=torch.float32, + ), + ) + + +# test_get_weights_col + + +def test_get_weights_col_awq(): + weights = MockWeights( + [ + "test_get_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "awq" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor( + [[100.0, 100.0], [100.0, 100.0]], + dtype=torch.float16, + ), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_gtpq(): + weights = MockWeights( + [ + "test_get_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "gptq" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_weights_col_marlin(): + weights = MockWeights( + [ + "test_get_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_weights_col_packed + + +def test_get_weights_col_packed_awq(): + weights = MockWeights( + [ + "test_get_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "awq" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +@pytest.mark.skip(reason="Review expected functionality") +def test_get_weights_col_packed_exl2(): + weights = MockWeights( + [ + "test_get_weights_col_packed_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_weights_col_packed_gptq(): + weights = MockWeights( + [ + "test_get_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "gptq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_packed_marlin(): + weights = MockWeights( + [ + "test_get_weights_col_packed_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + print(expected_weight) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_multi_weights_col + + +def test_get_multi_weights_col_awq(): + weights = MockWeights( + [ + "test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "awq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_multi_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + try: + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + except ValueError as e: + assert e.args[0] == "get_multi_weights_col is not supported for exl2" + + +def test_get_multi_weights_col_gptq(): + weights = MockWeights( + [ + "test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "gptq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_marlin(): + weights = MockWeights( + [ + "test_get_multi_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_multi_weights_row + + +def test_get_multi_weights_row_awq(): + weights = MockWeights( + [ + "test_get_multi_weights_row_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "awq" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_row_exl2(): + weights = MockWeights( + [ + "test_get_multi_weights_row_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + print(w) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_multi_weights_row_gptq(): + weights = MockWeights( + [ + "test_get_multi_weights_row_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "gptq" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_row_marlin(): + weights = MockWeights( + [ + "test_get_multi_weights_row_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" From d62668503906bd75c486ce658c349ba9db2236bb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Jun 2024 13:20:57 +0200 Subject: [PATCH 071/340] Removing IPEX_AVAIL. (#2115) * Removing IPEX_AVAIL. Chose to unify CPU and XPU under `ipex`. Most code is exactly similar except for a very few spots. The biggest number of spots is the kv-cache layout and the flash_xxx.py files. Since those files should be removed soon and factored away, we should not need them. * Forgot a few places. * Unrelated change. * Fixing HF_TOKEN. * HF_TOKEN --- .github/workflows/integration_tests.yaml | 2 +- clients/python/text_generation/types.py | 2 +- .../layers/attention/__init__.py | 6 ++--- .../layers/attention/{xpu.py => ipex.py} | 0 .../layers/layernorm.py | 5 ++-- .../text_generation_server/layers/rotary.py | 6 ++--- .../layers/tensor_parallel.py | 12 +++++----- .../custom_modeling/flash_dbrx_modeling.py | 4 ++-- .../custom_modeling/flash_mixtral_modeling.py | 4 ++-- .../models/flash_causal_lm.py | 9 +++----- .../models/flash_gpt2.py | 12 +++++----- .../models/flash_llama.py | 12 +++++----- .../models/flash_mistral.py | 12 +++++----- .../models/flash_neox.py | 12 +++++----- .../text_generation_server/models/flash_rw.py | 12 +++++----- .../models/flash_santacoder.py | 12 +++++----- server/text_generation_server/utils/dist.py | 4 ++-- .../utils/import_utils.py | 23 +++++++++++-------- 18 files changed, 75 insertions(+), 74 deletions(-) rename server/text_generation_server/layers/attention/{xpu.py => ipex.py} (100%) diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 4e111afe..59a8d304 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -37,5 +37,5 @@ jobs: export DOCKER_VOLUME=/mnt/cache export DOCKER_IMAGE=${{ inputs.docker_image }} export DOCKER_DEVICES=${{ inputs.docker_devices }} - export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} + export HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -s -vv integration-tests diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 497468d9..a56edaca 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -455,6 +455,6 @@ class DeployedModel(BaseModel): # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members # with model_ prefixes, since this disables guardrails for colliding fields: # https://github.com/pydantic/pydantic/issues/9177 - model_config = ConfigDict(protected_namespaces=()) + model_config = ConfigDict(protected_namespaces=()) model_id: str sha: str diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index a53c8e3b..e74180e7 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,4 +1,4 @@ -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM import os if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": @@ -7,7 +7,7 @@ if SYSTEM == "cuda": from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING -elif IPEX_AVAIL: - from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +elif SYSTEM == "ipex": + from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/ipex.py similarity index 100% rename from server/text_generation_server/layers/attention/xpu.py rename to server/text_generation_server/layers/attention/ipex.py diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py index 93e83dfa..ce5289f9 100644 --- a/server/text_generation_server/layers/layernorm.py +++ b/server/text_generation_server/layers/layernorm.py @@ -3,7 +3,6 @@ from torch import nn from accelerate import init_empty_weights from text_generation_server.utils.import_utils import ( SYSTEM, - IPEX_AVAIL, ) @@ -83,7 +82,7 @@ elif SYSTEM == "rocm": return super().forward(hidden_states), residual -elif IPEX_AVAIL: +elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex class FastLayerNorm(nn.LayerNorm): @@ -112,7 +111,7 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if IPEX_AVAIL: + if SYSTEM == "ipex": out = ipex.llm.functional.add_rms_norm( residual, hidden_states, diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 1892cf69..b14005e6 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -2,14 +2,14 @@ import os import torch from torch import nn -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "cuda": from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops -elif IPEX_AVAIL: +elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex @@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif IPEX_AVAIL: + elif SYSTEM == "ipex": ipex.llm.functional.rotary_embedding( query, key, sin, cos, query.size(-1), True ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 510dc2c6..038de258 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -3,9 +3,9 @@ from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.utils.import_utils import IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM -if IPEX_AVAIL: +if SYSTEM == "ipex": import intel_extension_for_pytorch as ipex @@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - if IPEX_AVAIL: + if SYSTEM == "ipex": ipex.distributed.all_gather_into_tensor( world_out, gather_input, group=self.process_group ) @@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - if IPEX_AVAIL: + if SYSTEM == "ipex": ipex.distributed.all_gather(world_output, output, group=self.process_group) else: torch.distributed.all_gather(world_output, output, group=self.process_group) @@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - if IPEX_AVAIL: + if SYSTEM == "ipex": ipex.distributed.all_reduce(out, group=self.process_group) else: torch.distributed.all_reduce(out, group=self.process_group) @@ -257,7 +257,7 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - if IPEX_AVAIL: + if SYSTEM == "ipex": ipex.distributed.all_reduce(out, group=self.process_group) else: torch.distributed.all_reduce(out, group=self.process_group) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 56292250..f81bfa10 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,9 +20,9 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from text_generation_server.utils.import_utils import IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM -if not IPEX_AVAIL: +if SYSTEM != "ipex": from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 6ea95411..2f7619af 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,9 +24,9 @@ import torch.distributed import numpy as np from torch import nn -from text_generation_server.utils.import_utils import IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM -if not IPEX_AVAIL: +if SYSTEM != "ipex": from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 633b066b..aa43107f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK @@ -768,12 +768,9 @@ class FlashCausalLM(Model): empty_cache() element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "xpu": - x = 1 - else: - x = BLOCK_SIZE // element_size + x = BLOCK_SIZE // element_size - if IPEX_AVAIL and SYSTEM == "cpu": + if SYSTEM == "ipex" and device == torch.device("cpu"): self.kv_cache = [ ( torch.empty( diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index 43f374e5..75c7203a 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -34,12 +34,12 @@ class FlashGPT2(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + else: + device = torch.device("cpu") dtype = torch.float16 if dtype is None else dtype - elif IPEX_AVAIL: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGPT2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index e023c4e0..76c522e3 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -17,7 +17,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM class FlashLlama(FlashCausalLM): @@ -34,12 +34,12 @@ class FlashLlama(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + else: + device = torch.device("cpu") dtype = torch.float16 if dtype is None else dtype - elif IPEX_AVAIL: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0c570487..78a09cf5 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -16,7 +16,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -38,12 +38,12 @@ class BaseFlashMistral(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + else: + device = torch.device("cpu") dtype = torch.float16 if dtype is None else dtype - elif IPEX_AVAIL: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 69d47e57..9c82bf52 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -33,12 +33,12 @@ class FlashNeoXSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + else: + device = torch.device("cpu") dtype = torch.float16 if dtype is None else dtype - elif IPEX_AVAIL: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index e087dcf1..e8087f23 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -34,12 +34,12 @@ class FlashRWSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + else: + device = torch.device("cpu") dtype = torch.float16 if dtype is None else dtype - elif IPEX_AVAIL: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 9626af60..83a6b92c 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( Weights, ) -from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -37,12 +37,12 @@ class FlashSantacoderSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + else: + device = torch.device("cpu") dtype = torch.float16 if dtype is None else dtype - elif IPEX_AVAIL: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 7d387563..36d63e86 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,7 +3,7 @@ import torch from datetime import timedelta from loguru import logger -from text_generation_server.utils.import_utils import IPEX_AVAIL +from text_generation_server.utils.import_utils import SYSTEM # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) @@ -69,7 +69,7 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. - if IPEX_AVAIL: + if SYSTEM == "ipex": import intel_extension_for_pytorch as ipex ipex.distributed.init_process_group( diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index a244417a..6d921721 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -37,7 +37,10 @@ def get_cpu_free_memory(device, memory_fraction): return free_memory -IPEX_AVAIL = is_ipex_available() +def noop(*args, **kwargs): + pass + + SYSTEM = None if torch.version.hip is not None: SYSTEM = "rocm" @@ -49,17 +52,19 @@ elif torch.version.cuda is not None and torch.cuda.is_available(): empty_cache = torch.cuda.empty_cache synchronize = torch.cuda.synchronize get_free_memory = get_cuda_free_memory -elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available(): - SYSTEM = "xpu" - empty_cache = torch.xpu.empty_cache - synchronize = torch.xpu.synchronize - get_free_memory = get_xpu_free_memory +elif is_ipex_available(): + SYSTEM = "ipex" + if hasattr(torch, "xpu") and torch.xpu.is_available(): + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory + else: + empty_cache = noop + synchronize = noop + get_free_memory = get_cpu_free_memory else: SYSTEM = "cpu" - def noop(*args, **kwargs): - pass - empty_cache = noop synchronize = noop get_free_memory = get_cpu_free_memory From 27ae4f79168cad670a42fcf91cd2db6733029ce2 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 25 Jun 2024 22:47:06 +0800 Subject: [PATCH 072/340] fix cpu and xpu issue (#2116) Signed-off-by: Wang, Yi A --- server/text_generation_server/models/flash_causal_lm.py | 5 ++++- server/text_generation_server/models/flash_gpt2.py | 3 ++- server/text_generation_server/models/flash_llama.py | 3 ++- server/text_generation_server/models/flash_mistral.py | 3 ++- server/text_generation_server/models/flash_neox.py | 3 ++- server/text_generation_server/models/flash_rw.py | 3 ++- server/text_generation_server/models/flash_santacoder.py | 3 ++- 7 files changed, 16 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index aa43107f..7f8268a9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -768,7 +768,10 @@ class FlashCausalLM(Model): empty_cache() element_size = torch.tensor([], dtype=dtype).element_size() - x = BLOCK_SIZE // element_size + if SYSTEM == "ipex" and device.type == "xpu": + x = 1 + else: + x = BLOCK_SIZE // element_size if SYSTEM == "ipex" and device == torch.device("cpu"): self.kv_cache = [ diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index 75c7203a..2d0f9fcc 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -37,9 +37,10 @@ class FlashGPT2(FlashCausalLM): elif SYSTEM == "ipex": if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGPT2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 76c522e3..9366706f 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -37,9 +37,10 @@ class FlashLlama(FlashCausalLM): elif SYSTEM == "ipex": if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 78a09cf5..16778ada 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -41,9 +41,10 @@ class BaseFlashMistral(FlashCausalLM): elif SYSTEM == "ipex": if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 9c82bf52..87ae570c 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -36,9 +36,10 @@ class FlashNeoXSharded(FlashCausalLM): elif SYSTEM == "ipex": if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index e8087f23..6ed1f6f7 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -37,9 +37,10 @@ class FlashRWSharded(FlashCausalLM): elif SYSTEM == "ipex": if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 83a6b92c..ab1e4516 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -40,9 +40,10 @@ class FlashSantacoderSharded(FlashCausalLM): elif SYSTEM == "ipex": if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") From 136fb7e9b9e980539b90bfd5b07aec9d5e9903d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 25 Jun 2024 16:53:20 +0200 Subject: [PATCH 073/340] Add pytest release marker (#2114) * Add pytest release marker Annotate a test with `@pytest.mark.release` and it only gets run with `pytest integration-tests --release`. * Mark many models as `release` to speed up CI --- integration-tests/conftest.py | 20 +++++++++++++++++++ integration-tests/models/test_bloom_560m.py | 3 +++ .../models/test_bloom_560m_sharded.py | 2 ++ .../models/test_completion_prompts.py | 3 +++ integration-tests/models/test_flash_awq.py | 3 +++ .../models/test_flash_awq_sharded.py | 2 ++ integration-tests/models/test_flash_falcon.py | 3 +++ integration-tests/models/test_flash_gemma.py | 3 +++ .../models/test_flash_gemma_gptq.py | 3 +++ integration-tests/models/test_flash_gpt2.py | 2 ++ .../models/test_flash_llama_exl2.py | 3 +++ .../models/test_flash_llama_gptq.py | 3 +++ .../models/test_flash_llama_gptq_marlin.py | 3 +++ .../models/test_flash_llama_marlin.py | 3 +++ integration-tests/models/test_flash_neox.py | 2 ++ .../models/test_flash_neox_sharded.py | 2 ++ .../models/test_flash_pali_gemma.py | 2 ++ integration-tests/models/test_flash_phi.py | 3 +++ integration-tests/models/test_flash_qwen2.py | 3 +++ .../models/test_flash_santacoder.py | 2 ++ .../models/test_flash_starcoder.py | 3 +++ .../models/test_flash_starcoder2.py | 3 +++ .../models/test_flash_starcoder_gptq.py | 3 +++ .../models/test_grammar_llama.py | 1 + .../test_grammar_response_format_llama.py | 2 ++ integration-tests/models/test_idefics.py | 2 ++ integration-tests/models/test_llava_next.py | 3 +++ integration-tests/models/test_mamba.py | 3 +++ integration-tests/models/test_mpt.py | 2 ++ integration-tests/models/test_mt0_base.py | 3 +++ integration-tests/models/test_neox.py | 2 ++ integration-tests/models/test_neox_sharded.py | 2 ++ integration-tests/models/test_t5_sharded.py | 2 ++ 33 files changed, 101 insertions(+) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 13337165..f5f38ac6 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -37,6 +37,26 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") DOCKER_DEVICES = os.getenv("DOCKER_DEVICES") +def pytest_addoption(parser): + parser.addoption( + "--release", action="store_true", default=False, help="run release tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "release: mark test as a release-only test") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--release"): + # --release given in cli: do not skip release tests + return + skip_release = pytest.mark.skip(reason="need --release option to run") + for item in items: + if "release" in item.keywords: + item.add_marker(skip_release) + + class ResponseComparator(JSONSnapshotExtension): rtol = 0.2 ignore_logprob = False diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index bdcbdc78..d413519e 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle): return bloom_560_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m(bloom_560, response_snapshot): response = await bloom_560.generate( @@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_all_params(bloom_560, response_snapshot): response = await bloom_560.generate( @@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index 3995f9e5..f9e8ed9c 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle): return bloom_560m_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): response = await bloom_560m_sharded.generate( @@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_sharded_load( bloom_560m_sharded, generate_load, response_snapshot diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index cafa8ea6..0efb6693 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle): # method for it. Instead, we use the `requests` library to make the HTTP request directly. +@pytest.mark.release def test_flash_llama_completion_single_prompt( flash_llama_completion, response_snapshot ): @@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt( assert response == response_snapshot +@pytest.mark.release def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): response = requests.post( f"{flash_llama_completion.base_url}/v1/completions", @@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn assert response == response_snapshot +@pytest.mark.release async def test_flash_llama_completion_many_prompts_stream( flash_llama_completion, response_snapshot ): diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index ead918c3..b500b15d 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle): return flash_llama_awq_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( @@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( @@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index a83614ac..4cf9b171 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): return flash_llama_awq_handle_sharded.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): response = await flash_llama_awq_sharded.generate( @@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_load_sharded( flash_llama_awq_sharded, generate_load, response_snapshot diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py index eac91984..0fb40fe7 100644 --- a/integration-tests/models/test_flash_falcon.py +++ b/integration-tests/models/test_flash_falcon.py @@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle): return flash_falcon_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_falcon(flash_falcon, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_falcon_all_params(flash_falcon, response_snapshot): @@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py index 7ab43111..7bee8dea 100644 --- a/integration-tests/models/test_flash_gemma.py +++ b/integration-tests/models/test_flash_gemma.py @@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle): return flash_gemma_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma(flash_gemma, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_all_params(flash_gemma, response_snapshot): @@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py index 8ac5f5a1..79d4cf24 100644 --- a/integration-tests/models/test_flash_gemma_gptq.py +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle): return flash_gemma_gptq_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq_all_params( @@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params( assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq_load( diff --git a/integration-tests/models/test_flash_gpt2.py b/integration-tests/models/test_flash_gpt2.py index 0c7977d0..cd73d0a3 100644 --- a/integration-tests/models/test_flash_gpt2.py +++ b/integration-tests/models/test_flash_gpt2.py @@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle): return flash_gpt2_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_gpt2(flash_gpt2, response_snapshot): response = await flash_gpt2.generate( @@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py index 18319f60..7169c999 100644 --- a/integration-tests/models/test_flash_llama_exl2.py +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle): return flash_llama_exl2_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): @@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_all_params( @@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params( assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_load( diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py index b87f054b..135f4b05 100644 --- a/integration-tests/models/test_flash_llama_gptq.py +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle): return flash_llama_gptq_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): @@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_load( diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py index 9c37a644..2274abce 100644 --- a/integration-tests/models/test_flash_llama_gptq_marlin.py +++ b/integration-tests/models/test_flash_llama_gptq_marlin.py @@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): return flash_llama_gptq_marlin_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_marlin_all_params( @@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params( assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_marlin_load( diff --git a/integration-tests/models/test_flash_llama_marlin.py b/integration-tests/models/test_flash_llama_marlin.py index e7c5ccbd..a89a1e41 100644 --- a/integration-tests/models/test_flash_llama_marlin.py +++ b/integration-tests/models/test_flash_llama_marlin.py @@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle): return flash_llama_marlin_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot): @@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_marlin_load( diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index 0289c61d..31848dae 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle): return flash_neox_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_flash_neox(flash_neox, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py index 8a491915..1f1e7225 100644 --- a/integration-tests/models/test_flash_neox_sharded.py +++ b/integration-tests/models/test_flash_neox_sharded.py @@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle): return flash_neox_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_neox(flash_neox_sharded, response_snapshot): response = await flash_neox_sharded.generate( @@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index 6be1750c..3ead3150 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -34,6 +34,7 @@ def get_cow_beach(): return f"data:image/png;base64,{encoded_string.decode('utf-8')}" +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): @@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py index 9d6ca566..73bb5edc 100644 --- a/integration-tests/models/test_flash_phi.py +++ b/integration-tests/models/test_flash_phi.py @@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle): return flash_phi_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_phi(flash_phi, response_snapshot): response = await flash_phi.generate( @@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_phi_all_params(flash_phi, response_snapshot): response = await flash_phi.generate( @@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) diff --git a/integration-tests/models/test_flash_qwen2.py b/integration-tests/models/test_flash_qwen2.py index 2963aeb4..c64f8732 100644 --- a/integration-tests/models/test_flash_qwen2.py +++ b/integration-tests/models/test_flash_qwen2.py @@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle): return flash_qwen2_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_qwen2(flash_qwen2, response_snapshot): response = await flash_qwen2.generate( @@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): response = await flash_qwen2.generate( @@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index 0f005f15..96a36aba 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -13,6 +13,7 @@ async def flash_santacoder(flash_santacoder_handle): return flash_santacoder_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_santacoder(flash_santacoder, response_snapshot): response = await flash_santacoder.generate( @@ -23,6 +24,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_santacoder_load( flash_santacoder, generate_load, response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 64e8b27c..dc5a8a53 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle): return flash_starcoder_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder(flash_starcoder, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot): @@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_starcoder2.py b/integration-tests/models/test_flash_starcoder2.py index ea665b6c..88341cfe 100644 --- a/integration-tests/models/test_flash_starcoder2.py +++ b/integration-tests/models/test_flash_starcoder2.py @@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle): return flash_starcoder2_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder2(flash_starcoder2, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): @@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder2_load( diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py index 329158b7..f1007d6e 100644 --- a/integration-tests/models/test_flash_starcoder_gptq.py +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -13,6 +13,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle): return flash_starcoder_gptq_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot): response = await flash_starcoder_gptq.generate( @@ -24,6 +25,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap assert response == generous_response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_starcoder_gptq_default_params( flash_starcoder_gptq, generous_response_snapshot @@ -40,6 +42,7 @@ async def test_flash_starcoder_gptq_default_params( assert response == generous_response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_starcoder_gptq_load( flash_starcoder_gptq, generate_load, generous_response_snapshot diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index ce5da8a9..4face9e1 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle): return non_flash_llama_grammar_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot): diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index 9c4c048e..ea25fa1c 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle): return llama_grammar_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot): @@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh assert chat_completion == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_grammar_response_format_llama_error_if_tools_not_installed( llama_grammar, diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index ac807b76..b7725f0b 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_idefics_two_images(idefics, response_snapshot): @@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_idefics_load(idefics, generate_load, response_snapshot): chicken = get_chicken() diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py index f5b290b1..ea277d71 100644 --- a/integration-tests/models/test_llava_next.py +++ b/integration-tests/models/test_llava_next.py @@ -26,6 +26,7 @@ async def flash_llava_next(flash_llava_next_handle): return flash_llava_next_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): @@ -41,6 +42,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): @@ -64,6 +66,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_load( diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index bf3701b4..bc946de8 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -13,6 +13,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle): return fused_kernel_mamba_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_mamba(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( @@ -24,6 +25,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( @@ -50,6 +52,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mamba_load( fused_kernel_mamba, generate_load, generous_response_snapshot diff --git a/integration-tests/models/test_mpt.py b/integration-tests/models/test_mpt.py index d58a8c5a..1832910a 100644 --- a/integration-tests/models/test_mpt.py +++ b/integration-tests/models/test_mpt.py @@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle): return mpt_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_mpt(mpt_sharded, response_snapshot): response = await mpt_sharded.generate( @@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mpt_load(mpt_sharded, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index c877056a..e53d8ed4 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -13,6 +13,7 @@ async def mt0_base(mt0_base_handle): return mt0_base_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_mt0_base(mt0_base, response_snapshot): response = await mt0_base.generate( @@ -27,6 +28,7 @@ async def test_mt0_base(mt0_base, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mt0_base_all_params(mt0_base, response_snapshot): response = await mt0_base.generate( @@ -49,6 +51,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mt0_base_load(mt0_base, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py index 7b88f86a..ee60441d 100644 --- a/integration-tests/models/test_neox.py +++ b/integration-tests/models/test_neox.py @@ -15,6 +15,7 @@ async def neox(neox_handle): return neox_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox(neox, response_snapshot): @@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox_load(neox, generate_load, response_snapshot): diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py index 8cee8765..a69227c9 100644 --- a/integration-tests/models/test_neox_sharded.py +++ b/integration-tests/models/test_neox_sharded.py @@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle): return neox_sharded_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox(neox_sharded, response_snapshot): @@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox_load(neox_sharded, generate_load, response_snapshot): diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py index 4b4cfd98..24003024 100644 --- a/integration-tests/models/test_t5_sharded.py +++ b/integration-tests/models/test_t5_sharded.py @@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle): return t5_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_t5_sharded(t5_sharded, response_snapshot): response = await t5_sharded.generate( @@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot): responses = await generate_load( From 8980bf43d7448e89f2dd35e120629570d5d387d4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Jun 2024 17:53:36 +0200 Subject: [PATCH 074/340] Fix CI . (#2118) Fix clippy. --- benchmark/src/main.rs | 4 +++- launcher/src/main.rs | 3 +-- router/src/main.rs | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 603b4087..2ee3d7c5 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -147,7 +147,9 @@ fn main() -> Result<(), Box> { tracing::info!("Downloading tokenizer"); // Parse Huggingface hub token - let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok(); + let auth_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 90b587cc..e4dd1bb5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -766,7 +766,7 @@ fn num_cuda_devices() -> Option { Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") { Ok(devices) => devices, Err(_) => env::var("ZE_AFFINITY_MASK").ok()?, - } + }, }; let n_devices = devices.split(',').count(); Some(n_devices) @@ -1229,7 +1229,6 @@ fn spawn_webserver( router_args.push("--otlp-service-name".to_string()); router_args.push(otlp_service_name); - // CORS origins for origin in args.cors_allow_origin.into_iter() { router_args.push("--cors-allow-origin".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index dcb9ce99..a7caec2e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -159,7 +159,9 @@ async fn main() -> Result<(), RouterError> { }); // Parse Huggingface hub token - let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok(); + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); // Tokenizer instance // This will only be used to validate payloads From 8a155b2d5bb5deb7bb008feef66f6a197f930236 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 25 Jun 2024 14:46:27 -0400 Subject: [PATCH 075/340] Enable multiple LoRa adapters (#2010) * feat: first draft load multiple lora * feat: load weights within layer and refactor lora pass * fix: refactor and reduce lora math * feat: baseline impl single request multi lora support * feat: prefer lorax implementation and port loading logic * fix: prefer adapter_data and refactors * feat: perfer loraxs custom punica kernels and add mlp loras * fix: adjust batch for bgmv * fix: adjust adapter_segments logic when in batch * fix: refactor and move changes to v3 proto * fix: pass model_id for all flash causal lms * fix: pass model_id for all causal and seq2seq lms * fix: add model_id to model test * feat: add lora support to mistral and refactors * feat: prefer model id in request * fix: include rust code for adapter id * feat: bump launcher and add new lora docs * feat: support base model generation and refactors * fix: rename doc to retry ci build * feat: support if vlm models * fix: add adapter_data param and avoid missing layers * fix: add adapter_data param to phi and neox * fix: update all models forwards to include adapter_data * fix: add model_id to IdeficsCausalLM * Update lora.md Fixed a typo * Update lora.md Fixing spam image * fix: add lora kernel to dockerfile, support running without kernels and refactors * fix: avoid dockerfile conflict * fix: refactors and adjust flash llama lora logic * fix: skip llama test due to CI issue (temp) * fix: skip llama test CI (temp) 2 * fix: revert skips and prefer updated ci token for tests * fix: refactors and helpful comments * fix: add noop in TensorParallelAdapterRowLinear too * fix: refactor and move shard_lora_weights logic * fix: exit early if no adapter_data --------- Co-authored-by: Derek --- benchmark/src/generation.rs | 1 + docs/source/_toctree.yml | 5 +- docs/source/basic_tutorials/launcher.md | 8 + docs/source/conceptual/lora.md | 65 +++ launcher/src/main.rs | 13 + proto/v3/generate.proto | 2 + router/client/src/v3/client.rs | 1 + router/client/src/v3/sharded_client.rs | 1 + router/src/infer/v2/queue.rs | 1 + router/src/infer/v3/queue.rs | 2 + router/src/lib.rs | 6 + router/src/server.rs | 2 + router/src/validation.rs | 3 + server/Makefile | 1 + server/Makefile-lorax-punica | 12 + server/tests/models/test_model.py | 7 +- .../adapters/__init__.py | 13 + .../text_generation_server/adapters/config.py | 44 ++ .../text_generation_server/adapters/lora.py | 482 ++++++++++++++++++ .../adapters/weights.py | 158 ++++++ server/text_generation_server/cli.py | 48 +- .../text_generation_server/layers/__init__.py | 6 + server/text_generation_server/layers/lora.py | 286 +++++++++++ .../text_generation_server/models/__init__.py | 4 +- server/text_generation_server/models/bloom.py | 1 + .../models/causal_lm.py | 1 + .../custom_modeling/flash_cohere_modeling.py | 1 + .../custom_modeling/flash_dbrx_modeling.py | 1 + .../custom_modeling/flash_gemma_modeling.py | 1 + .../custom_modeling/flash_gpt2_modeling.py | 1 + .../custom_modeling/flash_llama_modeling.py | 129 +++-- .../custom_modeling/flash_mistral_modeling.py | 93 +++- .../custom_modeling/flash_mixtral_modeling.py | 1 + .../custom_modeling/flash_neox_modeling.py | 1 + .../flash_pali_gemma_modeling.py | 1 + .../custom_modeling/flash_phi_modeling.py | 1 + .../custom_modeling/flash_qwen2_modeling.py | 1 + .../custom_modeling/flash_rw_modeling.py | 1 + .../flash_santacoder_modeling.py | 1 + .../flash_starcoder2_modeling.py | 1 + .../models/custom_modeling/idefics2.py | 1 + .../models/custom_modeling/llava_next.py | 1 + .../models/flash_causal_lm.py | 135 ++++- .../models/flash_cohere.py | 1 + .../models/flash_dbrx.py | 1 + .../models/flash_gemma.py | 1 + .../models/flash_gpt2.py | 1 + .../models/flash_llama.py | 83 ++- .../models/flash_mistral.py | 84 ++- .../models/flash_neox.py | 1 + .../models/flash_phi.py | 1 + .../models/flash_qwen2.py | 1 + .../text_generation_server/models/flash_rw.py | 1 + .../models/flash_santacoder.py | 1 + .../models/flash_starcoder2.py | 1 + .../models/galactica.py | 1 + .../text_generation_server/models/globals.py | 12 + .../text_generation_server/models/gpt_neox.py | 1 + .../text_generation_server/models/idefics.py | 1 + .../models/idefics_causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + server/text_generation_server/models/model.py | 157 +++++- server/text_generation_server/models/mpt.py | 1 + server/text_generation_server/models/opt.py | 1 + server/text_generation_server/models/phi.py | 1 + server/text_generation_server/models/rw.py | 1 + .../models/santacoder.py | 1 + .../models/seq2seq_lm.py | 1 + server/text_generation_server/models/t5.py | 1 + .../models/vlm_causal_lm.py | 4 +- server/text_generation_server/server.py | 41 +- .../text_generation_server/utils/adapter.py | 196 +++++++ server/text_generation_server/utils/hub.py | 38 ++ .../utils/merges/strategies.py | 223 ++++++++ .../utils/merges/utils.py | 108 ++++ server/text_generation_server/utils/peft.py | 25 +- .../text_generation_server/utils/segments.py | 66 +++ server/text_generation_server/utils/sgmv.py | 248 +++++++++ 78 files changed, 2776 insertions(+), 75 deletions(-) create mode 100644 docs/source/conceptual/lora.md create mode 100644 server/Makefile-lorax-punica create mode 100644 server/text_generation_server/adapters/__init__.py create mode 100644 server/text_generation_server/adapters/config.py create mode 100644 server/text_generation_server/adapters/lora.py create mode 100644 server/text_generation_server/adapters/weights.py create mode 100644 server/text_generation_server/layers/lora.py create mode 100644 server/text_generation_server/utils/adapter.py create mode 100644 server/text_generation_server/utils/merges/strategies.py create mode 100644 server/text_generation_server/utils/merges/utils.py create mode 100644 server/text_generation_server/utils/segments.py create mode 100644 server/text_generation_server/utils/sgmv.py diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b82d23ba..5e739703 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,6 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + adapter_id: None, }) .collect(); diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7599562a..c9b4efd9 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -60,6 +60,9 @@ - local: conceptual/speculation title: Speculation (Medusa, ngram) - local: conceptual/guidance - title: How Guidance Works (via outlines) + title: How Guidance Works (via outlines + - local: conceptual/lora + title: LoRA (Low-Rank Adaptation) + title: Conceptual Guides diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index f6175925..5e40146f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -416,6 +416,14 @@ Options: [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] +``` +## LORA_ADAPTERS +```shell + --lora-adapters + Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request + + [env: LORA_ADAPTERS=] + ``` ## HELP ```shell diff --git a/docs/source/conceptual/lora.md b/docs/source/conceptual/lora.md new file mode 100644 index 00000000..08df767c --- /dev/null +++ b/docs/source/conceptual/lora.md @@ -0,0 +1,65 @@ +# LoRA (Low-Rank Adaptation) + +## What is LoRA? + +LoRA is a technique that allows for efficent fine-tuning a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task. + +LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed. + +## How is it used? + +LoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA: + +Technically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as: + +- fine-tuning a language model on a small dataset +- fine-tuning a language model on a domain-specific dataset +- fine-tuning a language model on a dataset with limited labels + +## Optimizing Inference with LoRA + +LoRA's can be used during inference by mutliplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but due to awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels/and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with mulitple LoRA models. + +## Serving multiple LoRA adapters with TGI + +Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned. + +In practice its often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset. + +Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library. + +### Specifying LoRA models + +To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example: + +```bash +LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia +``` + +In the server logs, you will see the following message: + +```txt +Loading adapter weights into model: predibase/customer_support +Loading adapter weights into model: predibase/dbpedia +``` + +## Generate text + +You can then use these models in generation requests by specifying the `lora_model` parameter in the request payload. For example: + +```json +curl 127.0.0.1:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs": "Hello who are you?", + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/customer_support" + } +}' +``` + +> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon. + +An updated tutorial with detailed examples will be published soon. Stay tuned! diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e4dd1bb5..e6928958 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -454,6 +454,11 @@ struct Args { /// Control the maximum number of inputs that a client can send in a single request #[clap(default_value = "4", long, env)] max_client_batch_size: usize, + + /// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during + /// startup that will be available to callers via the `adapter_id` field in a request. + #[clap(long, env)] + lora_adapters: Option, } #[derive(Debug)] @@ -487,6 +492,7 @@ fn shard_manager( max_total_tokens: usize, max_batch_size: Option, max_input_tokens: usize, + lora_adapters: Option, otlp_endpoint: Option, otlp_service_name: String, log_level: LevelFilter, @@ -623,6 +629,11 @@ fn shard_manager( envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into())); } + // Lora Adapters + if let Some(lora_adapters) = lora_adapters { + envs.push(("LORA_ADAPTERS".into(), lora_adapters.into())); + } + // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container if let Some(huggingface_hub_cache) = huggingface_hub_cache { @@ -1064,6 +1075,7 @@ fn spawn_shards( let rope_scaling = args.rope_scaling; let rope_factor = args.rope_factor; let max_batch_size = args.max_batch_size; + let lora_adapters = args.lora_adapters.clone(); thread::spawn(move || { shard_manager( model_id, @@ -1089,6 +1101,7 @@ fn spawn_shards( max_total_tokens, max_batch_size, max_input_tokens, + lora_adapters, otlp_endpoint, otlp_service_name, max_log_level, diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 01cc43fd..926c878e 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -134,6 +134,8 @@ message Request { repeated uint32 blocks = 9; /// Paged attention slots repeated uint32 slots = 10; + /// LORA adapter index + optional string adapter_id = 11; } message Batch { diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 9a3892fb..a996b14f 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -177,6 +177,7 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, + adapter_id: None, }); n_tokens += max_input_length; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 94002f55..ae8a899b 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + adapter_id: None, }; let batch = Batch { id: u64::MAX, diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 3725c03e..93cf9469 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -429,6 +429,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, + adapter_id: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 0b66142a..ba65b9b6 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -351,6 +351,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -491,6 +492,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, + adapter_id: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/lib.rs b/router/src/lib.rs index 5d201937..126726c6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub grammar: Option, + + /// Lora adapter id + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub adapter_id: Option, } fn default_max_new_tokens() -> Option { @@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters { seed: None, top_n_tokens: None, grammar: None, + adapter_id: None, } } diff --git a/router/src/server.rs b/router/src/server.rs index aa872df9..6e6b93b6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -673,6 +673,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, + ..Default::default() }, }) .collect(); @@ -1115,6 +1116,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, + ..Default::default() }, }; diff --git a/router/src/validation.rs b/router/src/validation.rs index bb9ad318..e2bf5a5d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -202,6 +202,7 @@ impl Validation { decoder_input_details, top_n_tokens, grammar, + adapter_id, .. } = request.parameters; @@ -383,6 +384,7 @@ impl Validation { parameters, stopping_parameters, top_n_tokens, + adapter_id, }) } @@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest { pub parameters: ValidParameters, pub stopping_parameters: ValidStoppingParameters, pub top_n_tokens: u32, + pub adapter_id: Option, } #[derive(Error, Debug)] diff --git a/server/Makefile b/server/Makefile index 5257b876..0099c56a 100644 --- a/server/Makefile +++ b/server/Makefile @@ -4,6 +4,7 @@ include Makefile-vllm include Makefile-awq include Makefile-eetq include Makefile-selective-scan +include Makefile-lorax-punica unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica new file mode 100644 index 00000000..72f06f76 --- /dev/null +++ b/server/Makefile-lorax-punica @@ -0,0 +1,12 @@ +lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc + +build-lorax-punica: + if [ ! -d 'lorax-punica' ]; then \ + git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ + fi + cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) + cd lorax-punica && git submodule update --init --recursive + cd lorax-punica/server/punica_kernels && python setup.py build + +install-lorax-punica: build-lorax-punica + cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py index 32bcd45f..8441e8c6 100644 --- a/server/tests/models/test_model.py +++ b/server/tests/models/test_model.py @@ -17,7 +17,12 @@ def get_test_model(): tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") model = TestModel( - torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu") + "test_model_id", + torch.nn.Linear(1, 1), + tokenizer, + False, + torch.float32, + torch.device("cpu"), ) return model diff --git a/server/text_generation_server/adapters/__init__.py b/server/text_generation_server/adapters/__init__.py new file mode 100644 index 00000000..8697cb9e --- /dev/null +++ b/server/text_generation_server/adapters/__init__.py @@ -0,0 +1,13 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/__init__.py +# License: Apache License Version 2.0, January 2004 + +from text_generation_server.adapters.weights import ( + AdapterBatchData, + AdapterBatchMetadata, +) + +__all__ = [ + "AdapterBatchData", + "AdapterBatchMetadata", +] diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py new file mode 100644 index 00000000..5261d4b5 --- /dev/null +++ b/server/text_generation_server/adapters/config.py @@ -0,0 +1,44 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/config.py +# License: Apache License Version 2.0, January 2004 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple + +import torch + +from text_generation_server.adapters.weights import AdapterWeights + +if TYPE_CHECKING: + from text_generation_server.models.model import Model + + +@dataclass +class ModuleMap: + module_name: str + module_weights: Dict[str, Tuple[torch.Tensor, str]] + + +@dataclass +class AdapterConfig(ABC): + base_model_name_or_path: str + + @abstractmethod + def map_weights_for_model( + self, + adapter_weights: Dict[int, AdapterWeights], + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + pass + + @abstractmethod + def load_batched_adapter_weights( + self, + model: "Model", + module_map: ModuleMap, + layer_type: str, + unused_weight_names: Set[str], + dynamic: bool, + ) -> Optional[AdapterWeights]: + pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py new file mode 100644 index 00000000..87543be2 --- /dev/null +++ b/server/text_generation_server/adapters/lora.py @@ -0,0 +1,482 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/lora.py +# License: Apache License Version 2.0, January 2004 + +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +from peft import LoraConfig as _LoraConfig +from torch.distributed import ProcessGroup + +from text_generation_server.adapters.config import AdapterConfig, ModuleMap + +from text_generation_server.adapters.weights import ( + AdapterBatchMetadata, + AdapterWeights, + BatchAdapterWeights, +) +from text_generation_server.utils.sgmv import ( + BGMV_MAX_RANK, + MAX_RANK_CUSTOM, + get_tmp_tensors, + orient_for_rank, + pad_rank, + use_cutlass_shrink, +) + +if TYPE_CHECKING: + from text_generation_server.models.model import Model + + +def get_start_stop_idxs_for_rank(offset, size, rank, world_size): + block_size = size // world_size + start = offset + rank * block_size + stop = offset + (rank + 1) * block_size + return start, stop + + +def shard_on_dim( + t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup +): + world_size = process_group.size() + rank = process_group.rank() + + size = t.shape[dim] + start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size) + + if dim == 0: + tensor = t[start:stop] + elif dim == 1: + tensor = t[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + + return tensor + + +def shard_lora_weights( + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + split_dim: int, + process_group: ProcessGroup, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # [hidden_size, r] + weights_a = [ + shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a + ] + + # [r, hidden_size] + weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b] + + return weights_a, weights_b + + +@dataclass +class LoraConfig(AdapterConfig): + r: int + target_modules: Optional[Union[List[str], str]] + fan_in_fan_out: bool + lora_alpha: int + use_rslora: bool + + def map_weights_for_model( + self, + adapter_weights: Dict[int, AdapterWeights], + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + adapter_weight_names = set() + module_map = {} + for weight_name in weight_names: + lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" + lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" + if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: + continue + + module_map[weight_name] = { + "lora_A": (adapter_weights[lora_a_name], lora_a_name), + "lora_B": (adapter_weights[lora_b_name], lora_b_name), + } + adapter_weight_names.add(lora_a_name) + adapter_weight_names.add(lora_b_name) + return module_map, adapter_weight_names + + def load_batched_adapter_weights( + self, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + dynamic: bool, + ) -> Optional[AdapterWeights]: + return LoraWeights.load( + self, + model, + module_map, + layer_type, + unused_weight_names, + ) + + @classmethod + def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": + hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) + return cls( + base_model_name_or_path=hf_config.base_model_name_or_path, + r=hf_config.r, + target_modules=hf_config.target_modules, + fan_in_fan_out=hf_config.fan_in_fan_out, + lora_alpha=hf_config.lora_alpha, + use_rslora=( + hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False + ), + ) + + +class LoraWeights(AdapterWeights): + """LoRA weights for a single adapter merged across all layers.""" + + def __init__( + self, + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + adapter_config: LoraConfig, + ): + self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 + self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 + + self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._is_transposed = False + + # [num_layers, hidden_size, r] + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + self._weights_a = torch.stack(weights_a) + + # [num_layers, r, hidden_size] + self._weights_b = torch.stack(weights_b) + + self.adapter_config = adapter_config + + @property + def weights_a(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_b + + @property + def weights_a_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_b + + def _transpose_weights(self): + if self._use_cutlass_shrink: + # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation + self._weights_a = self._weights_a.transpose(1, 2).contiguous() + self._weights_b = self._weights_b.transpose(1, 2).contiguous() + self._is_transposed = not self._is_transposed + + @classmethod + def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: + return [BatchLoraWeights] + + @classmethod + def load( + cls, + config: LoraConfig, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + ) -> Optional[AdapterWeights]: + nlayers = model.get_num_layers_for_type(layer_type) + lora_a_list = [None] * nlayers + lora_b_list = [None] * nlayers + + for layer_id in range(nlayers): + key = (layer_id, layer_type) + weight_name, layer = model.target_to_layer[key] + base_weight = layer.base_layer.linear.weight + base_device = base_weight.device + + if weight_name not in module_map: + # There is no LoRA weight for this layer type in the adapter + return None + + lora_a, lora_a_name = module_map[weight_name]["lora_A"] + lora_a = lora_a.to(base_device, model.dtype) + + lora_b, lora_b_name = module_map[weight_name]["lora_B"] + lora_b = lora_b.to(base_device, model.dtype) + + scale = get_scaling_factor( + config.lora_alpha, + config.r, + uses_rslora=config.use_rslora, + ) + + unused_weight_names.discard(lora_a_name) + unused_weight_names.discard(lora_b_name) + + # Merge scaling factor into lora_b due to associativity of matrix multiplication: + # (A * B) * C = A * (B * C) + lora_a_list[layer_id] = lora_a.transpose(0, 1) + lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale + + # pad lora ranks to be compatible with sgmv + lora_a_list = [ + pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list + ] + lora_b_list = [ + pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list + ] + + if lora_a_list: + # update rank if it was padded + padded_rank = lora_a_list[0].size(1) + config.r = padded_rank + + return LoraWeights( + *shard_lora_weights( + weights_a=lora_a_list, + weights_b=lora_b_list, + split_dim=0 if model.is_row_parallel(layer_type) else 1, + process_group=model.process_group, + ), + config, + ) + + +@dataclass +class RankSegments: + rank: int + + lora_a_ptr: torch.Tensor + lora_b_ptr: torch.Tensor + + # prefill (sgmv) + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor + segment_starts: torch.Tensor + segment_ends: torch.Tensor + + # decode (bgmv) + indices: torch.Tensor + + +@dataclass +class BatchLoraWeights(BatchAdapterWeights): + lora_a: Dict[int, torch.Tensor] + lora_b: Dict[int, torch.Tensor] + adapter_index_configs: Dict[int, LoraConfig] + rank_data: Dict[int, RankSegments] + use_sgmv: bool + + def has_adapter(self, adapter_index: int) -> bool: + return adapter_index in self.adapter_index_configs + + def can_vectorize(self, pg: ProcessGroup) -> bool: + return all( + rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + for rank_data in self.rank_data.values() + ) + + @classmethod + def key(cls) -> str: + return "lora" + + @classmethod + def load( + self, + adapter_weights: Dict[int, AdapterWeights], + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Optional["BatchLoraWeights"]: + adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()} + adapter_weights = { + k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights) + } + if not adapter_weights: + return None + + first_weights = next(iter(adapter_weights.values())) + device = first_weights.weights_a.device + segment_indices = meta.segment_indices + + lora_a = { + idx: adapter_weights[idx].weights_a + for idx in segment_indices + if idx in adapter_weights + } + lora_b = { + idx: adapter_weights[idx].weights_b + for idx in segment_indices + if idx in adapter_weights + } + + max_rank = max( + ( + adapter_weights[idx].lora_a_r + for idx in segment_indices + if idx in adapter_weights + ), + default=0, + ) + + if prefill or max_rank > BGMV_MAX_RANK: + use_sgmv = True + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + else: + use_sgmv = False + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a_t.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b_t.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + + adapter_index_configs = { + idx: adapter_weights[idx].adapter_config + for idx in segment_indices + if idx in adapter_weights + } + + adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} + + rank_indices = defaultdict(list) + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx not in adapter_weights: + continue + rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + + if prefill_head_indices is not None: + j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] + for head_index in prefill_head_indices: + # j cannot go out of bounds as that would mean there are tokens without corresponding adapters + if head_index < meta.adapter_segments[j]: + prefill_head_segment_ends[-1] += 1 + else: + prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) + prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) + j += 1 + + rank_data = {} + for rank, indices in rank_indices.items(): + tmp_shrink = None + tmp_expand = None + segment_starts = None + segment_ends = None + batch_indices = None + + if use_sgmv: + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors( + lora_a_ptr_indices.size(0), rank, device + ) + segment_starts = meta.adapter_segments[indices] + segment_ends = meta.adapter_segments[[i + 1 for i in indices]] + if prefill_head_indices is not None: + for i, segment_index in enumerate(indices): + segment_starts[i] = prefill_head_segment_starts[segment_index] + segment_ends[i] = prefill_head_segment_ends[segment_index] + else: + rank_indices = set(indices) + batch_indices = [ + adapter_to_segment[idx] for idx in meta.adapter_indices.tolist() + ] + batch_indices = [ + idx if idx in rank_indices else -1 for idx in batch_indices + ] + batch_indices = torch.tensor( + batch_indices, dtype=torch.int64, device=device + ) + + rank_data[rank] = RankSegments( + rank=rank, + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=lora_a_ptr[indices], + lora_b_ptr=lora_b_ptr[indices], + segment_starts=segment_starts, + segment_ends=segment_ends, + indices=batch_indices, + ) + + return BatchLoraWeights( + lora_a=lora_a, + lora_b=lora_b, + adapter_index_configs=adapter_index_configs, + rank_data=rank_data, + use_sgmv=use_sgmv, + ) + + +def get_scaling_factor( + lora_alpha: int, + r: int, + uses_rslora: bool = False, +) -> float: + """Computes the scaling factor for the lora weights.""" + if uses_rslora: + return lora_alpha / (r**0.5) + return lora_alpha / r + + +def _convert_lora(v: AdapterWeights) -> AdapterWeights: + if hasattr(v, "lora_weights"): + return v.lora_weights + return v diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py new file mode 100644 index 00000000..8f658756 --- /dev/null +++ b/server/text_generation_server/adapters/weights.py @@ -0,0 +1,158 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/weights.py +# License: Apache License Version 2.0, January 2004 + +from abc import ABC, abstractclassmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Type + +import torch + + +@dataclass +class AdapterBatchMetadata: + # [batch_size] + adapter_indices: torch.Tensor + + # [num_adapters] + adapter_set: Set[int] + + # [num_segments + 1] + adapter_segments: torch.Tensor + + # [num_segments] + # maps from segment index to adapter index, i.e.: + # segment_indices[s] == adapter_indices[i] + segment_indices: List[int] + + +class AdapterWeights(ABC): + @abstractclassmethod + def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]: + pass + + @property + def speculative_tokens(self) -> int: + return 0 + + +class BatchAdapterWeights(ABC): + @abstractclassmethod + def has_adapter(self, adapter_index: int) -> bool: + pass + + @abstractclassmethod + def key(cls) -> str: + pass + + @abstractclassmethod + def load( + cls, + adapter_weights: Dict[int, AdapterWeights], + meta: "AdapterBatchMetadata", + prefill: bool, + prefill_head_indices: torch.Tensor, + ) -> Optional["BatchAdapterWeights"]: + pass + + +class LayerAdapterWeights: + """Adapter weights that apply to a particular layer.""" + + def __init__(self): + self.adapter_weights: Dict[int, AdapterWeights] = {} + + def add_adapter(self, adapter_idx: int, weights: AdapterWeights): + self.adapter_weights[adapter_idx] = weights + + def remove_adapter(self, adapter_idx: int): + if adapter_idx not in self.adapter_weights: + return + del self.adapter_weights[adapter_idx] + + @property + def max_speculative_tokens(self) -> int: + return max( + adapter_weights.speculative_tokens + for adapter_weights in self.adapter_weights.values() + ) + + def is_empty(self) -> bool: + return len(self.adapter_weights) == 0 + + def get_data( + self, + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Dict[str, BatchAdapterWeights]: + # bucket adapters by batch class + adapter_batch_types: Dict[ + Type[BatchAdapterWeights], Dict[int, AdapterWeights] + ] = defaultdict(dict) + for adapter_index, adapter_weights in self.adapter_weights.items(): + for batch_type in adapter_weights.get_batch_types(): + adapter_batch_types[batch_type][adapter_index] = adapter_weights + + batch_data = {} + for batch_type, adapter_weights in adapter_batch_types.items(): + batched_weights = batch_type.load( + adapter_weights, meta, prefill, prefill_head_indices + ) + if batched_weights is not None: + batch_data[batch_type.key()] = batched_weights + return batch_data + + +@dataclass +class AdapterBatchData: + meta: AdapterBatchMetadata + + # layer type -> adapter type -> batch weight data + data: Dict[str, Dict[str, BatchAdapterWeights]] + + prefill: bool + + @staticmethod + def from_meta( + meta: AdapterBatchMetadata, + weights: Dict[str, LayerAdapterWeights], + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> "AdapterBatchData": + data = {} + for k, v in weights.items(): + if v.is_empty(): + continue + data[k] = v.get_data( + meta, prefill, prefill_head_indices if k == "lm_head" else None + ) + return AdapterBatchData(meta=meta, data=data, prefill=prefill) + + def ranks(self) -> Set[int]: + # TODO(travis): refactor to be less coupled to lora implementation + ranks = set() + for layer_data in self.data.values(): + lora_data = layer_data.get("lora") + if lora_data is None: + continue + + for rank_data in lora_data.rank_data.values(): + ranks.add(rank_data.rank) + + return ranks + + def layer_names(self) -> Set[str]: + return set(self.data.keys()) + + def adapter_keys(self) -> Set[str]: + adapter_keys = set() + for layer_data in self.data.values(): + adapter_keys.update(layer_data.keys()) + return adapter_keys + + @property + def max_rank(self) -> int: + ranks = self.ranks() + return max(ranks) if len(ranks) > 0 else 0 diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 18cad071..68ae95dd 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -79,6 +79,18 @@ def serve( if otlp_endpoint is not None: setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) + lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) + + # split on comma and strip whitespace + lora_adapter_ids = ( + [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else [] + ) + + if len(lora_adapter_ids) > 0: + logger.warning( + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + ) + # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value @@ -93,6 +105,7 @@ def serve( ) server.serve( model_id, + lora_adapter_ids, revision, sharded, quantize, @@ -113,6 +126,7 @@ def download_weights( logger_level: str = "INFO", json_output: bool = False, trust_remote_code: bool = False, + merge_lora: bool = False, ): # Remove default handler logger.remove() @@ -143,18 +157,28 @@ def download_weights( ) is not None if not is_local_model: - try: - adapter_config_filename = hf_hub_download( - model_id, revision=revision, filename="adapter_config.json" - ) - utils.download_and_unload_peft( - model_id, revision, trust_remote_code=trust_remote_code - ) - is_local_model = True - utils.weight_files(model_id, revision, extension) - return - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass + # TODO: maybe reverse the default value of merge_lora? + # currently by default we don't merge the weights with the base model + if merge_lora: + try: + adapter_config_filename = hf_hub_download( + model_id, revision=revision, filename="adapter_config.json" + ) + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + is_local_model = True + utils.weight_files(model_id, revision, extension) + return + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + else: + try: + utils.peft.download_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + except Exception: + pass try: import json diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index c29dd092..32c8d121 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead # Just to add the `load` methods. from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d + +from text_generation_server.layers.lora import ( + LoraLinear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py new file mode 100644 index 00000000..0bb6db41 --- /dev/null +++ b/server/text_generation_server/layers/lora.py @@ -0,0 +1,286 @@ +import math +import os +from typing import TYPE_CHECKING, Optional, Tuple, List + +import torch +import torch.distributed +from accelerate import init_empty_weights +from torch import nn +from torch.nn import functional as F +from torch.distributed import ProcessGroup + +from text_generation_server.utils.sgmv import ( + add_lora_a_bgmv, + add_lora_b_bgmv, + has_sgmv, + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + orient_for_rank, +) + +if TYPE_CHECKING: + from text_generation_server.adapters import AdapterBatchData + from text_generation_server.adapters.lora import BatchLoraWeights + + +class LoraLinear(nn.Module): + def __init__( + self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup + ): + super().__init__() + self.base_layer = base_layer + self.layer_id = layer_id + self.process_group = process_group + + def forward_layer_type( + self, + result: torch.Tensor, + input: torch.Tensor, + adapter_data: "AdapterBatchData", + layer_type: str, + start_idx: int, + end_idx: int, + ) -> torch.Tensor: + if adapter_data is None: + return result + data = adapter_data.data.get(layer_type) + data: Optional["BatchLoraWeights"] = ( + data.get("lora") if data is not None else None + ) + + if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + # In tensor-parallel configurations, each GPU processes a specific segment of the output. + # The 'result' tensor represents the full output, which can vary in size based on + # the layer type (e.g., attention vs. feed-forward layers). We define the current + # segment using start_idx and end_idx. If the segment size doesn't match this GPU's + # slice of 'result', we create a zero tensor of the correct size for LoRA computation. + # This approach ensures accurate LoRA application across various layer sizes and + # configurations, adapting to different model architectures and parallelization strategies. + # + # Example scenarios where this is necessary: + # 1. The adapter's size doesn't evenly divide across GPUs. + # 2. We're processing the last segment which might be smaller. + # 3. Different projection layers (q, k, v) have different sizes. + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result + + for r, rank_segments in data.rank_data.items(): + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if lora_a_ptr is None or lora_b_ptr is None: + raise ValueError("LoRA data is missing") + + if data.use_sgmv: + # Use SGMV for prefill + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + v = torch.zeros( + (input.size(0), r), dtype=input.dtype, device=input.device + ) + # TODO: error with [-1, 0], but not [0, -1] + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) + + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + else: + for adapter_index in adapter_data.meta.adapter_set: + if data is not None and data.has_adapter(adapter_index): + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) + layer_result = self.forward_lora( + input, data, adapter_index, adapter_mask + ) + result[:, start_idx:end_idx] += layer_result + + return result + + def forward_lora( + self, + input: torch.Tensor, + data: "BatchLoraWeights", + adapter_index: int, + adapter_mask: torch.Tensor, + ) -> torch.Tensor: + lora_a = data.lora_a[adapter_index][self.layer_id, :, :] + lora_b = data.lora_b[adapter_index][self.layer_id, :, :] + + lora_a = orient_for_rank(lora_a, lora_b.size(0)) + + a_out = input @ lora_a + if self.process_group.size() > 1: + a_out = self.collect_lora_a(a_out) + + result = (a_out @ lora_b) * adapter_mask + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Implemented in subclasses") + + +class TensorParallelMultiAdapterLinear(LoraLinear): + def __init__( + self, + base_layer: nn.Module, + layer_id: int, + layer_names: List[str], + sizes: List[int], + process_group: ProcessGroup, + ): + super().__init__(base_layer, layer_id, process_group) + self.layer_names = layer_names + self.sizes = sizes + + @classmethod + def load( + cls, + base_layer: nn.Module, + layer_id: int, + layer_names: List[str], + sizes: List[int], + process_group: ProcessGroup, + ): + return TensorParallelMultiAdapterLinear( + base_layer, layer_id, layer_names, sizes, process_group + ) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + # noop if no layer names are provided (e.g. for models without adapters) + if self.layer_names is None: + return result + + # handle models like Bloom that have inputs of shape + # (batch_size, sequence_length, hidden_size) + # we need to reshape them to (batch_size * sequence_length, hidden_size) + # for the LoRA computation, then reshape back + prev_shape = result.shape + is_3d = len(input.shape) >= 3 + if is_3d: + input = input.reshape(-1, input.shape[-1]) + result = result.reshape(-1, result.shape[-1]) + + offset = 0 + for i, layer_name in enumerate(self.layer_names): + start_idx = offset // self.process_group.size() + # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple + # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It + # ensures correct slicing of the result tensor, accommodating variations like grouped-query + # attention where k_proj and v_proj differ from q_proj. This allows precise application of + # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the + # different projection sizes across layers and model architectures. + if self.sizes is not None: + offset += self.sizes[i] + end_idx = offset // self.process_group.size() + else: + end_idx = result.shape[1] + + result = self.forward_layer_type( + result, input, adapter_data, layer_name, start_idx, end_idx + ) + + if is_3d: + result = result.reshape(prev_shape) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise. + # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-gather for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + gathered_tensors = [ + torch.empty_like(a_out) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(gathered_tensors, a_out) + return torch.cat(gathered_tensors, dim=1) + + +class TensorParallelAdapterRowLinear(LoraLinear): + def __init__(self, base_layer, layer_id, layer_name, process_group): + super().__init__(base_layer, layer_id, process_group) + self.layer_name = layer_name + + @classmethod + def load(cls, base_layer, layer_id, layer_name, process_group): + return cls(base_layer, layer_id, layer_name, process_group) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + if self.layer_name is None: + return result + + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 + stride = result.shape[-1] // self.process_group.size() + start_idx = self.process_group.rank() * stride + end_idx = (self.process_group.rank() + 1) * stride + + self.forward_layer_type( + result, input, adapter_data, self.layer_name, start_idx, end_idx + ) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. + # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + torch.distributed.all_reduce(a_out, group=self.process_group) + return a_out diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 76dca3dc..648fcee9 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -6,7 +6,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi -from typing import Optional +from typing import Optional, List from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate @@ -253,6 +253,7 @@ for data in ModelType: def get_model( model_id: str, + lora_adapter_ids: Optional[List[str]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -595,6 +596,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 38006502..17aa12e8 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index e896c831..10c64c66 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -538,6 +538,7 @@ class CausalLM(Model): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 6d315ba5..2850a6f3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index f81bfa10..9d56e4ef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 04d05cd6..a4fd4740 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 0c01f56a..7e7510c7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: token_embeds = self.embed_tokens(input_ids) position_embeds = self.embed_positions(position_ids) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0d06d104..c48ed268 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -38,6 +38,8 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -51,43 +53,61 @@ if SYSTEM == "rocm": raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): # Only defined in granite. bias = getattr(config, "attention_bias", False) + head_size = config.hidden_size // config.num_attention_heads + sizes = None + prefixes = None - # if specific model type, load the correct attention if config.model_type == "phi3": - return TensorParallelColumnLinear.load_qkv( + prefix = f"{prefix}.qkv_proj" + base_layer = TensorParallelColumnLinear.load_qkv( config, - prefix=f"{prefix}.qkv_proj", + prefix=prefix, weights=weights, bias=bias, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": - return TensorParallelColumnLinear.load_qkv( + prefix = f"{prefix}.W_pack" + base_layer = TensorParallelColumnLinear.load_qkv( config, - prefix=f"{prefix}.W_pack", + prefix=prefix, weights=weights, bias=bias, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + else: + prefixes = ["q_proj", "k_proj", "v_proj"] + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] + base_layer = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=bias, + ) - # otherwise, load the default attention based on the number of heads - return TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=bias, + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, ) class FlashLlamaAttention(torch.nn.Module): def __init__( self, + index: int, prefix: str, config, weights, @@ -121,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) + self.index = index - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -145,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module): slots, input_lengths, max_s, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -190,11 +220,13 @@ class FlashLlamaAttention(torch.nn.Module): max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class LlamaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -209,29 +241,54 @@ class LlamaMLP(nn.Module): ), ) ) + prefixes = None + sizes = None + # Fuse gate and up proj bias = getattr(config, "mlp_bias", False) if config.model_type == "phi3": - self.gate_up_proj = TensorParallelColumnLinear.load_gate_up( + gate_up_proj = TensorParallelColumnLinear.load_gate_up( config, prefix=f"{prefix}.gate_up_proj", weights=weights, bias=bias, ) else: - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"gate_proj", f"up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=bias, ) - self.down_proj = TensorParallelRowLinear.load( + + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=bias, ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) + self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -239,7 +296,7 @@ class LlamaMLP(nn.Module): # TODO: This is a hotfix to be removed & properly refactored. self.quantize = config.quantize - def forward(self, hidden_states): + def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" and self.hidden_act == "silu" @@ -253,20 +310,27 @@ class LlamaMLP(nn.Module): device="cuda", ) _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out) + return self.down_proj(out, adapter_data) else: - gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashLlamaLayer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() self.self_attn = FlashLlamaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + self.mlp = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) - self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -289,6 +353,7 @@ class FlashLlamaLayer(nn.Module): slots, input_lengths, max_s, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -303,6 +368,7 @@ class FlashLlamaLayer(nn.Module): slots, input_lengths, max_s, + adapter_data, ) # faster post attention rms norm @@ -310,7 +376,7 @@ class FlashLlamaLayer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -325,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers = nn.ModuleList( [ FlashLlamaLayer( + index=layer_id, prefix=( f"model.layers.{layer_id}" if not prefix @@ -360,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -382,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module): slots, input_lengths, max_s, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -423,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -436,6 +506,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 77a8a384..d1ba5564 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -38,6 +38,8 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, get_linear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig): class MistralAttention(torch.nn.Module): - def __init__( - self, - prefix: str, - config, - weights, - ): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.max_past = ( config.sliding_window if config.sliding_window is not None else -1 @@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = TensorParallelColumnLinear.load_multi( + query_key_value = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], dim=0, @@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module): bias=False, ) - self.o_proj = TensorParallelRowLinear.load( + head_size = config.hidden_size // config.num_attention_heads + self.query_key_value = TensorParallelMultiAdapterLinear.load( + query_key_value, + layer_id, + ["q_proj", "k_proj", "v_proj"], + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + layer_id, + "o_proj", + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module): max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class MistralMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, layer_id): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -244,19 +263,37 @@ class MistralMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=False, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + ["gate_proj", "up_proj"], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=False, ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + layer_id, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -264,7 +301,7 @@ class MistralMLP(nn.Module): # TODO: This is a hotfix to be removed & properly refactored. self.quantize = config.quantize - def forward(self, hidden_states): + def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" and self.hidden_act == "silu" @@ -278,20 +315,27 @@ class MistralMLP(nn.Module): device="cuda", ) _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out) + return self.down_proj(out, adapter_data) else: - gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class MistralLayer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, layer_id): super().__init__() self.self_attn = MistralAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = MistralMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -315,6 +359,7 @@ class MistralLayer(nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -330,6 +375,7 @@ class MistralLayer(nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ) # faster post attention rms norm @@ -337,7 +383,7 @@ class MistralLayer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module): prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights, + layer_id=layer_id, ) for layer_id in range(config.num_hidden_layers) ] @@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor] = None, ): hidden_states = inputs_embeds # Get rotary cos and sin for this forward @@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: @@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module): max_s, true_max_s, prefill_cache_indices, + adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2f7619af..2e839d15 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d399be2f..b87fd4ca 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 91c709e4..1f998e5a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): # Unused here pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 6dda4b2b..3f445f97 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index df5a8ae9..69f38c3a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 7d3c72a7..04d4ba51 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2ae0908c..badfc367 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -483,6 +483,7 @@ class FlashSantacoderForCausalLM(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index c3e2e099..f6a2e15d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 51fd7c02..a83bc1c6 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module): pixel_attention_mask: Optional[torch.BoolTensor] = None, # Unused here image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index de9673aa..9a670140 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module): # Unused for this model pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, ): inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7f8268a9..f7678762 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -13,6 +13,7 @@ from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Iterable, Optional, Tuple, List, Type, Dict +from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM @@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS import text_generation_server.models.globals as tgi_globals from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( empty_cache, @@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch): top_n_tokens: List[int] top_n_tokens_tensor: torch.Tensor + # Adapter metadata for each request + adapter_meta: AdapterBatchMetadata + # Number of blocks in this batch num_blocks: int # Maximum number of blocks @@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] + adapter_indices_list = [] + adapter_set = set() + # Cumulative length cumulative_length = 0 cumulative_max_length = 0 @@ -225,6 +233,10 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) + adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append(torch.full((input_length,), adapter_index)) + adapter_set.add(adapter_index) + # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() @@ -296,6 +308,10 @@ class FlashCausalLMBatch(Batch): max_length, input_length + max_new_tokens + speculative_length ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) @@ -339,6 +355,11 @@ class FlashCausalLMBatch(Batch): input_lengths, dtype=torch.int32, device=device ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + if all_prefill_logprobs: prefill_head_indices = None prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 @@ -393,6 +414,12 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), speculative_ids=None, ) @@ -443,6 +470,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] + adapter_set = set() num_blocks = 0 max_blocks = 0 @@ -471,6 +499,11 @@ class FlashCausalLMBatch(Batch): top_n_tokens.append(self.top_n_tokens[idx]) + adapter_index = tgi_globals.ADAPTER_TO_INDEX.get( + self.requests[idx].adapter_id, 0 + ) + adapter_set.add(adapter_index) + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -498,6 +531,7 @@ class FlashCausalLMBatch(Batch): # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] @@ -513,6 +547,11 @@ class FlashCausalLMBatch(Batch): # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + return type(self)( batch_id=self.batch_id, requests=requests, @@ -543,6 +582,12 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) @classmethod @@ -596,6 +641,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_set = set() + adapter_segment_builder = SegmentConcatBuilder() start_slots = [] block_tables = [] @@ -613,6 +666,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 cumulative_slots = 0 + cumulative_adapter_indices_size = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -637,6 +691,21 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) + all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -680,6 +749,8 @@ class FlashCausalLMBatch(Batch): else None ) + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -710,6 +781,12 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) def __len__(self): @@ -719,6 +796,7 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, num_layers: int, @@ -738,6 +816,7 @@ class FlashCausalLM(Model): self.kv_cache = [] super(FlashCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, @@ -895,12 +974,13 @@ class FlashCausalLM(Model): total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size free_memory = get_free_memory(self.device, MEMORY_FRACTION) + batch_num_blocks = batch.num_blocks if batch is not None else 0 num_blocks = ( # Leave 5% for some wiggle room int((free_memory * 0.95) // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. - + batch.num_blocks + + batch_num_blocks ) del batch @@ -1001,7 +1081,7 @@ class FlashCausalLM(Model): ) def forward( - self, batch: FlashCausalLMBatch + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: @@ -1080,6 +1160,7 @@ class FlashCausalLM(Model): max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, + adapter_data=adapter_data, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -1116,7 +1197,34 @@ class FlashCausalLM(Model): prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - out, speculative_logits = self.forward(batch) + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1) + .expand(B, new_length) + .reshape(-1) + ) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) + + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + prefill, + batch.prefill_head_indices, + ) + + out, speculative_logits = self.forward(batch, adapter_data) if prefill: next_token_logits = ( @@ -1128,8 +1236,13 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( + len(batch) + ) + else: next_token_logits = out + next_adapter_indices = batch.adapter_meta.adapter_indices speculate = get_speculate() ( @@ -1195,6 +1308,12 @@ class FlashCausalLM(Model): # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ + end_index - 1 + ] + # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: @@ -1220,6 +1339,16 @@ class FlashCausalLM(Model): batch.position_ids = next_position_ids + accepted_ids batch.input_lengths_tensor += accepted_ids batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) if prefill and prefill_logprobs: # Get prefill logprobs diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index 1077d78e..9f8bcb3f 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashCohere, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py index ffb6d5a6..2aba6a00 100644 --- a/server/text_generation_server/models/flash_dbrx.py +++ b/server/text_generation_server/models/flash_dbrx.py @@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashDbrx, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 1b7b2772..aa1ae9ac 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index 2d0f9fcc..323fcafa 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -69,6 +69,7 @@ class FlashGPT2(FlashCausalLM): model = FlashGPT2ForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super(FlashGPT2, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 9366706f..d996b9c3 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -1,9 +1,10 @@ +import os import torch import torch.distributed from opentelemetry import trace from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional +from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( @@ -13,12 +14,24 @@ from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, + hub, ) tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import SYSTEM +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + class FlashLlama(FlashCausalLM): def __init__( @@ -29,6 +42,7 @@ class FlashLlama(FlashCausalLM): speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -78,6 +92,7 @@ class FlashLlama(FlashCausalLM): model = FlashLlamaForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -88,3 +103,69 @@ class FlashLlama(FlashCausalLM): rank=rank, world_size=world_size, ) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. + if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 16778ada..209eca83 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -3,7 +3,7 @@ import torch.distributed from opentelemetry import trace from transformers import AutoTokenizer, AutoConfig -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.flash_causal_lm import set_sliding_window @@ -21,6 +21,18 @@ from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + class BaseFlashMistral(FlashCausalLM): def __init__( self, @@ -83,6 +95,7 @@ class BaseFlashMistral(FlashCausalLM): torch.distributed.barrier(group=self.process_group) num_layers, num_kv_heads, head_size = self.get_layer_config(model) super().__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=num_layers, @@ -102,6 +115,75 @@ class BaseFlashMistral(FlashCausalLM): model.model.head_size, ) + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. + if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + # TODO: this is a hack to avoid the gate_proj for + # FlashStarcoder2 that doesnt have these layers + if hasattr(layer.mlp, "gate_up_proj"): + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL + class FlashMistral(BaseFlashMistral): def __init__( diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 87ae570c..ac1fd573 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -69,6 +69,7 @@ class FlashNeoXSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashNeoXSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.gpt_neox.layers), diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 0cc67cec..7e108d05 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -90,6 +90,7 @@ class FlashPhi(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashPhi, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 9fcfce9d..23528f0b 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral): torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 6ed1f6f7..b1f75adc 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -78,6 +78,7 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashRWSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index ab1e4516..e1a7b36e 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -80,6 +80,7 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashSantacoderSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 1ac731be..369e9e4c 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral): torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index f39bd1e9..30c92d90 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 970a673b..cc2f172a 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,6 +1,7 @@ import torch import os from loguru import logger +from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli @@ -32,3 +33,14 @@ MODEL_ID = None def set_model_id(model_id: str): global MODEL_ID MODEL_ID = model_id + + +# NOTE: eventually we should move this into the router and pass back the +# index in all cases. +global ADAPTER_TO_INDEX +ADAPTER_TO_INDEX: Dict[str, int] = None + + +def set_adapter_to_index(adapter_to_index: Dict[str, int]): + global ADAPTER_TO_INDEX + ADAPTER_TO_INDEX = adapter_to_index diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 8d2cb0e1..c37cfb7d 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index c1fe03e4..f2955bd0 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -83,6 +83,7 @@ class IDEFICSSharded(IdeficsCausalLM): torch.distributed.barrier(group=self.process_group) super(IdeficsCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index f507d669..6c562980 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -634,6 +634,7 @@ class IdeficsCausalLM(Model): tokenizer.add_special_tokens({"pad_token": ""}) super(IdeficsCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 3133a137..9189b45c 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -453,6 +453,7 @@ class Mamba(Model): model = MambaModel(config, weights) torch.distributed.barrier(group=self.process_group) super(Mamba, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 4f35b0aa..c90fd38a 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -2,12 +2,24 @@ import inspect import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type +from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict +from collections import defaultdict from transformers import PreTrainedTokenizerBase, PretrainedConfig from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse +from text_generation_server.adapters.weights import LayerAdapterWeights +from text_generation_server.utils.adapter import ( + load_and_merge_adapters, + AdapterParameters, + AdapterSource, +) +from loguru import logger + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + B = TypeVar("B", bound=Batch) @@ -15,6 +27,7 @@ B = TypeVar("B", bound=Batch) class Model(ABC): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, requires_padding: bool, @@ -24,7 +37,9 @@ class Model(ABC): world_size: int = 1, sliding_window: Optional[int] = None, speculate: Optional[int] = None, + adapter_id: str = BASE_MODEL_ADAPTER_ID, ): + self.model_id = model_id self.model = model.eval() self.tokenizer = tokenizer @@ -42,6 +57,13 @@ class Model(ABC): self.world_size = world_size self.sliding_window = sliding_window if sliding_window != -1 else None + self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( + LayerAdapterWeights + ) + self.target_to_layer = self.adapter_target_to_layer() + self.loaded_adapters = set() + self.static_adapter_id = adapter_id + if speculate is None: speculate = get_speculate() self.speculate = speculate @@ -119,3 +141,136 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) + + @property + def supports_adapter_loading(self) -> bool: + return False + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + return {} + + @property + def adapter_layers(self) -> List[str]: + return [] + + @property + def default_traced_adapter_layers(self) -> List[str]: + return [] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 0 + + def is_row_parallel(self, layer_type: str) -> bool: + return False + + @property + def max_speculative_tokens(self) -> int: + return max( + [ + weights.max_speculative_tokens + for weights in self.layer_to_adapter_weights.values() + ], + default=0, + ) + + def load_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + api_token: str, + dynamic: bool = True, + ): + """Loads adapter weights from disk / host memory on the GPU. + + adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded + into model. Otherwise, the adapter weights are applied during the forward + pass and stored separately from the base model parameters. + """ + if adapter_index in self.loaded_adapters: + # Adapter already loaded + return + + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if dynamic and not self.dynamic_adapter_loading_enabled: + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) + + logger.info( + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + ) + weight_names = tuple([v[0] for v in self.target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + self.model_id, + adapter_parameters, + adapter_source, + adapter_index, + weight_names, + api_token, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + for layer_name in self.adapter_layers: + adapter_weights = adapter_config.load_batched_adapter_weights( + self, + module_map, + layer_name, + unused_weight_names, + dynamic, + ) + + if adapter_weights is None: + continue + + layer_weights = self.layer_to_adapter_weights[layer_name] + layer_weights.add_adapter(adapter_index, adapter_weights) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + self.loaded_adapters.add(adapter_index) + + def offload_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + ): + """Offloads the adapter weights from GPU to CPU or disk.""" + if adapter_index not in self.loaded_adapters: + # Adapter already offloaded + return + + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if not self.dynamic_adapter_loading_enabled: + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) + + for layer_name in self.adapter_layers: + if layer_name in self.layer_to_adapter_weights: + self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) + + self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index 65180e73..1e79b25f 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -90,6 +90,7 @@ class MPTSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 1f4fbfcd..6d7d07f5 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -63,6 +63,7 @@ class OPTSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index d68866c1..93d42b2b 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -60,6 +60,7 @@ class Phi(CausalLM): model = PhiForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 50f6ead8..37ca277b 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -62,6 +62,7 @@ class RW(CausalLM): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 323e4324..caddbe19 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -62,6 +62,7 @@ class SantaCoder(CausalLM): ) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3bd09556..d454d804 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -575,6 +575,7 @@ class Seq2SeqLM(Model): tokenizer.bos_token_id = model.config.decoder_start_token_id super(Seq2SeqLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 8e0735e5..adef664c 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM): torch.distributed.barrier(group=self.process_group) super(Seq2SeqLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 8b5819d1..218d1167 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -222,7 +222,9 @@ class VlmCausalLM(BaseFlashMistral): return VlmCausalLMBatch def forward( - self, batch: VlmCausalLMBatch + self, + batch: VlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index a0347cd8..aee287c6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -29,7 +29,10 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id +from text_generation_server.models.globals import set_model_id, set_adapter_to_index +from text_generation_server.utils.adapter import ( + AdapterParameters, +) class SignalHandler: @@ -192,6 +195,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_id: str, + lora_adapter_ids: Optional[List[str]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -203,6 +207,7 @@ def serve( ): async def serve_inner( model_id: str, + lora_adapter_ids: Optional[List[str]], revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, @@ -211,6 +216,7 @@ def serve( trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" + adapter_to_index = {} if sharded: server_urls = [ unix_socket_template.format(uds_path, rank) @@ -224,6 +230,7 @@ def serve( try: model = get_model( model_id, + lora_adapter_ids, revision, sharded, quantize, @@ -232,10 +239,33 @@ def serve( trust_remote_code, max_input_tokens, ) + + if len(lora_adapter_ids) > 0: + for index, adapter_id in enumerate(lora_adapter_ids): + # TODO: improve non merged adapter loading and long term + # improve adapter loading as a whole + adapter_parameters = AdapterParameters( + adapter_ids=[adapter_id], + weights=None, # will be set to 1 + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + adapter_index = index + 1 + adapter_to_index[adapter_id] = adapter_index + model.load_adapter( + adapter_parameters, + None, # adapter_source + adapter_index, + None, # api_token + False, # dynamic + ) + except Exception: logger.exception("Error when initializing model") raise + set_adapter_to_index(adapter_to_index) server = aio.server( interceptors=[ ExceptionInterceptor(), @@ -266,6 +296,13 @@ def serve( set_model_id(model_id) asyncio.run( serve_inner( - model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, ) ) diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py new file mode 100644 index 00000000..4e2492de --- /dev/null +++ b/server/text_generation_server/utils/adapter.py @@ -0,0 +1,196 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/adapter.py +# License: Apache License Version 2.0, January 2004 + +import warnings +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING, Set, Tuple + +from safetensors.torch import load_file +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer + +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils.merges.strategies import merge_adapters + +from text_generation_server.utils import hub +from text_generation_server.adapters.lora import LoraConfig + + +if TYPE_CHECKING: + from text_generation_server.adapters.config import AdapterConfig, ModuleMap + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + + +@dataclass +class AdapterParameters: + adapter_ids: Tuple[str] + weights: Tuple[float] + merge_strategy: NotImplemented + density: float + majority_sign_method: NotImplemented + + +@dataclass +class AdapterSource: + adapter_id: str + model_id: str + revision: str + + +def load_and_merge_adapters( + model_id: str, + adapter_parameters: AdapterParameters, + adapter_source: str, + adapter_index: int, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + if len(adapter_parameters.adapter_ids) == 1: + return load_module_map( + model_id, + adapter_parameters.adapter_ids[0], + adapter_source, + weight_names, + api_token, + trust_remote_code, + ) + + adapter_params = AdapterParametersContainer( + adapter_parameters, adapter_source, adapter_index + ) + return _load_and_merge( + model_id, adapter_params, weight_names, api_token, trust_remote_code + ) + + +@dataclass +class AdapterParametersContainer: + adapter_parameters: AdapterParameters + adapter_source: str + adapter_index: int + + def __hash__(self) -> int: + return self.adapter_index + + +@lru_cache(maxsize=32) +def _load_and_merge( + model_id: str, + adapter_params: AdapterParametersContainer, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + params = adapter_params.adapter_parameters + + adapters_to_merge = [] + merged_weight_names = set() + tokenizer = None + for adapter_id in params.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + raise ValueError("Base model adapter cannot be merged.") + + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( + load_module_map( + model_id, + adapter_id, + adapter_params.adapter_source, + weight_names, + api_token, + trust_remote_code, + ) + ) + + adapters_to_merge.append((module_map, adapter_config)) + merged_weight_names = merged_weight_names.union(adapter_weight_names) + if tokenizer is None: + tokenizer = adapter_tokenizer + + if len(adapters_to_merge) == 0: + raise ValueError("No adapters to merge.") + + module_map, adapter_config = merge_adapters(adapters_to_merge, params) + return module_map, adapter_config, merged_weight_names, tokenizer + + +def check_architectures( + model_id: str, + adapter_id: str, + adapter_config: "AdapterConfig", + trust_remote_code: bool = False, +): + try: + if not adapter_config.base_model_name_or_path: + # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None) + return + + expected_config = AutoConfig.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + model_config = AutoConfig.from_pretrained( + adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code + ) + except Exception as e: + warnings.warn( + f"Unable to check architecture compatibility for adapter '{adapter_id}' " + f"against model '{model_id}'. Assuming they are compatible. Error: {e}" + ) + return + + if model_config.architectures == expected_config.architectures: + warnings.warn( + f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " + f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + else: + # TODO(travis): revisit this when we support clasification heads which will not use CausalLM + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + + +@lru_cache(maxsize=128) +def load_module_map( + model_id: str, + adapter_id: str, + adapter_source: str, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + revision = "main" + + adapter_config = LoraConfig.load(adapter_id, api_token) + if adapter_config.base_model_name_or_path != model_id: + check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) + + adapter_filenames = hub._cached_adapter_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) + + try: + adapter_tokenizer = AutoTokenizer.from_pretrained( + adapter_config.config_path, + token=api_token, + trust_remote_code=trust_remote_code, + ) + except Exception: + # Adapter does not have a tokenizer, so fallback to base model tokenizer + adapter_tokenizer = None + + # load adapter weights from all shards (should have relatively small memory footprint) + adapter_weights = {} + for filename in adapter_filenames: + adapter_weights.update(load_file(filename)) + + # map the model weights to the relevant adapter weights (LoRA A and B matrices) + module_map, adapter_weight_names = adapter_config.map_weights_for_model( + adapter_weights, weight_names + ) + return module_map, adapter_config, adapter_weight_names, adapter_tokenizer diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index b56484f6..db412aeb 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] +def _cached_adapter_weight_files( + adapter_id: str, revision: Optional[str], extension: str +) -> List[str]: + """Guess weight files from the cached revision snapshot directory""" + d = _get_cached_revision_directory(adapter_id, revision) + if not d: + return [] + filenames = _adapter_weight_files_from_dir(d, extension) + return filenames + + def _cached_weight_files( model_id: str, revision: Optional[str], extension: str ) -> List[str]: @@ -60,6 +71,33 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: return filenames +def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_files_from_dir, that's also what is done there + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "training" not in f + ] + return filenames + + +def _adapter_config_files_from_dir(d: Path) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_files_from_dir, that's also what is done there + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(".json") and "arguments" not in f and "args" not in f + ] + return filenames + + def _get_cached_revision_directory( model_id: str, revision: Optional[str] ) -> Optional[Path]: diff --git a/server/text_generation_server/utils/merges/strategies.py b/server/text_generation_server/utils/merges/strategies.py new file mode 100644 index 00000000..3b885313 --- /dev/null +++ b/server/text_generation_server/utils/merges/strategies.py @@ -0,0 +1,223 @@ +import copy +from abc import ABC +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union + +import torch + + +class AdapterParameters: + def __init__( + self, adapter_ids, weights, merge_strategy, density, majority_sign_method + ): + self.adapter_ids = adapter_ids + self.weights = weights + self.merge_strategy = merge_strategy + self.density = density + self.majority_sign_method = majority_sign_method + + +from text_generation_server.utils.merges.utils import ( + calculate_majority_sign_mask, + disjoint_merge, + prune, +) + +if TYPE_CHECKING: + from text_generation_server.adapters.lora import LoraConfig + from text_generation_server.utils.adapter import ModuleMap + + +def _apply_weights( + tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor +) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + t = tensors + else: + t = torch.stack(tensors, dim=0) + + # element-wise weighting of each task tensor + # need to unsqueeze weights to match task tensor dimensions + # for multiplication to apply element-wise + while len(t.shape) > len(w.shape): + w = w.unsqueeze(-1) + return t * w + + +class MergeStrategy(ABC): + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + raise NotImplementedError() + + +class LinearMerge(MergeStrategy): + def __init__(self, **kwargs): + pass + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class TiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="magnitude") for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + return disjoint_merge(weighted_task_tensors, majority_sign_mask) + + +class DareLinearMerge(MergeStrategy): + def __init__(self, density: float, **kwargs): + self.density = density + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class DareTiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +strategy_registry: Dict[str, Type[MergeStrategy]] = { + "linear": LinearMerge, + "ties": TiesMerge, + "dare_linear": DareLinearMerge, + "dare_ties": DareTiesMerge, +} + + +def merge_adapters( + adapters: List[Tuple["ModuleMap", "LoraConfig"]], + merge_params: AdapterParameters, +) -> Tuple["ModuleMap", "LoraConfig"]: + # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower() + strategy_name = "linear" + + weights = merge_params.weights + if not weights: + weights = torch.ones(len(adapters)) + else: + weights = torch.tensor(weights) + + merge_config = { + "density": merge_params.density, + # "majority_sign_method": MajoritySignMethodEnum.Name( + # merge_params.majority_sign_method + # ).lower(), + "majority_sign_method": "total", + } + merge_strategy = strategy_registry[strategy_name](**merge_config) + + module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) + lora_configs = [] + weight_name_to_adapter_idx = defaultdict(list) + + # input is list of (module_map, lora_config) tuples + # convert into dict[k][param_name] -> list of tensors + for idx, (module_map, lora_config) in enumerate(adapters): + for weight_name, data in module_map.items(): + weight_name_to_adapter_idx[weight_name].append(idx) + for k, (param_data, param_name) in data.items(): + module_maps[weight_name][k][param_name].append(param_data) + lora_configs.append(lora_config) + + # validate lora configs are compatible + _validate_lora_configs(lora_configs) + + # merge tensors for each module such that we have a single ModuleMap: + # dict[k] -> merged tensor + merged_module_map: "ModuleMap" = defaultdict(dict) + for weight_name, data in module_maps.items(): + indices = weight_name_to_adapter_idx[weight_name] + param_weights = weights[indices] + for k, param_data in data.items(): + for param_name, tensors in param_data.items(): + merged_tensor = merge_strategy.merge(tensors, param_weights) + merged_module_map[weight_name][k] = (merged_tensor, param_name) + + # merge lora configs + merged_lora_config = _merge_lora_configs(lora_configs) + + return merged_module_map, merged_lora_config + + +def _validate_lora_configs(lora_configs: List["LoraConfig"]): + # check that all configs have the same rank + ranks = set(lora_config.r for lora_config in lora_configs) + if len(ranks) > 1: + raise ValueError( + f"unable to merge adapters, lora configs have different ranks: {ranks}" + ) + + if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs): + raise ValueError( + "unable to merge adapters, lora configs have no target modules" + ) + + +def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig": + merged_lora_config = copy.copy(lora_configs[0]) + + # merge target modules as a union operation + merged_target_modules = sorted( + set( + module + for lora_config in lora_configs + for module in lora_config.target_modules + ) + ) + merged_lora_config.target_modules = merged_target_modules + + return merged_lora_config diff --git a/server/text_generation_server/utils/merges/utils.py b/server/text_generation_server/utils/merges/utils.py new file mode 100644 index 00000000..d9ad3278 --- /dev/null +++ b/server/text_generation_server/utils/merges/utils.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# From: https://github.com/huggingface/peft/pull/1364 +# Copyright 2024-present the HuggingFace Inc. team. +# Modifications by Predibase, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import torch + + +def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + """ + mask = torch.zeros_like(tensor).reshape(-1) + k = int(density * tensor.reshape(-1).shape[0]) + top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + torch.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: torch.Tensor, + density: float, + method: Literal["magnitude", "random"], + rescale: bool = False, +) -> torch.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + if density >= 1: + return tensor + elif density < 0: + raise ValueError("Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask( + tensor: torch.Tensor, method: Literal["total", "frequency"] = "total" +): + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`torch.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = (sign * tensor.abs()).sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = torch.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors, majority_sign_mask): + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0) diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index 48ca264b..0ea89267 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -1,5 +1,5 @@ import os -import json +from typing import Union from loguru import logger import torch @@ -43,3 +43,26 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): model.save_pretrained(cache_dir, safe_serialization=True) model.config.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir) + + +def download_peft( + model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool +): + torch_dtype = torch.float16 + try: + _model = AutoPeftModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + except Exception: + _model = AutoPeftModelForSeq2SeqLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + logger.info("Peft model downloaded.") diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py new file mode 100644 index 00000000..f5961102 --- /dev/null +++ b/server/text_generation_server/utils/segments.py @@ -0,0 +1,66 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/segments.py +# License: Apache License Version 2.0, January 2004 + +from typing import List, Tuple, Union + +import torch + + +def find_segments( + adapter_indices: Union[torch.Tensor, List[int]] +) -> Tuple[List[int], List[int]]: + segments = [0] + segment_indices = [] + + if isinstance(adapter_indices, torch.Tensor): + # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first + adapter_indices = adapter_indices.cpu().tolist() + + start_index = 0 + for i in range(1, len(adapter_indices)): + if adapter_indices[i] != adapter_indices[i - 1]: + segments.append(i) + segment_indices.append(adapter_indices[i - 1]) + start_index = i + + # Handle the last segment + if start_index < len(adapter_indices): + segments.append(len(adapter_indices)) + segment_indices.append(adapter_indices[-1]) + + return segments, segment_indices + + +class SegmentConcatBuilder: + def __init__(self): + self.adapter_segment_indices = [] + self.adapter_segment_tensors = [] + + def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): + # Update adapter segments + if self.adapter_segment_tensors: + # Because we have already processed at least one batch, remove the 0 start index + # from this batch denoting the beginning of the segment, then offset all segment + # positions by the value of the last segment in the previous batch to account for + # the concatenation. + adapter_segments = ( + adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] + ) + + if ( + self.adapter_segment_indices + and self.adapter_segment_indices[-1] == segment_indices[0] + ): + # If the last segment in the previous batch is the same as the first segment in this batch, + # then we merge them together into a single segment. In effect, this means removing it from + # the segment indices of this batch, and extending the segment span by removing the segment + # end index from the previous batch. + segment_indices = segment_indices[1:] + self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] + + self.adapter_segment_indices.extend(segment_indices) + self.adapter_segment_tensors.append(adapter_segments) + + def build(self) -> Tuple[torch.Tensor, List[int]]: + return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py new file mode 100644 index 00000000..e0aec25f --- /dev/null +++ b/server/text_generation_server/utils/sgmv.py @@ -0,0 +1,248 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/sgmv.py +# License: Apache License Version 2.0, January 2004 + +import os +import warnings +from functools import lru_cache +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +try: + import punica_kernels as _kernels + + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) +except ImportError: + warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") + _kernels = None + HAS_SGMV = False + + +MIN_SGMV_RANK = 8 +MIN_RANK_CUSTOM = 16 +MAX_RANK_CUSTOM = 128 +SGMV_BLOCK_SIZE = 16 +BGMV_MAX_RANK = 64 + + +def has_sgmv() -> bool: + return HAS_SGMV + + +def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: + """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" + if not has_sgmv(): + return t + + # tensor parallelism will result in effective rank being divided by world_size, + # so we need to scale the min rank to offset that effect + min_rank = MIN_SGMV_RANK * world_size + + # if we're at or below the min rank, pad up to the min rank + # otherwise, pad to the nearest multiple of the block size + current_rank = t.size(dim) + target_rank = ( + min_rank + if current_rank <= min_rank + else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE + ) + if current_rank == target_rank: + return t + + pad_size = target_rank - current_rank + + # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad = [0, 0] * t.dim() + pad[(t.dim() - dim - 1) * 2 + 1] = pad_size + pad = tuple(pad) + + return F.pad(t, pad, mode="constant", value=0.0) + + +def use_cutlass_shrink(lora_rank: int) -> bool: + return lora_rank < MIN_RANK_CUSTOM + + +def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: + if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: + return t.transpose(0, 1) + return t + + +# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py +def add_lora_sgmv_cutlass( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.Tensor, + s_end: torch.Tensor, + layer_idx: int, + lora_rank: int, +): + """ + Semantics: + y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H1]`. + wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H2]`. + s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. + s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. + layer_idx: Layer index of the weight matrices. + """ + if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: + # Custom SGMV shrink only supports rank 16, 32, 64, 128 + _add_lora_sgmv_cutlass_legacy( + y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank + ) + return + + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) + tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) + + +def _add_lora_sgmv_cutlass_legacy( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +): + tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +@lru_cache(maxsize=1) +def get_tmp_tensor(device: torch.device) -> torch.Tensor: + return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) + + +@lru_cache(maxsize=32) +def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: + tmp_size = _kernels.sgmv_cutlass_tmp_size(size) + return torch.empty((tmp_size,), dtype=torch.uint8, device=device) + + +def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: + return torch.empty((size,), dtype=torch.uint8, device=device) + + +def get_tmp_expand_size(size: int) -> int: + return _kernels.sgmv_cutlass_tmp_size(size) + + +def get_tmp_tensors( + nsegments: int, lora_rank: int, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: + if use_cutlass_shrink(lora_rank) and has_sgmv(): + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + else: + tmp_shrink = get_tmp_tensor(device) + tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device) + return tmp_shrink, tmp_expand + + +def lora_a_sgmv_cutlass( + x: torch.Tensor, + tmp: torch.Tensor, + wa_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +) -> torch.Tensor: + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + else: + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + return v + + +def lora_b_sgmv_cutlass( + y: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, +): + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +""" +Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + +Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + v: Shape: `[B, R]`. Temporary vector. + x: Shape: `[B, H1]`. Input vectors. + wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. + wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. +""" + + +def add_lora_a_bgmv( + v: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) + + +def add_lora_b_bgmv( + y: torch.Tensor, + v: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) + + +def segmented_matmul( + y: torch.Tensor, + x: torch.Tensor, + w: List[torch.Tensor], + b: List[torch.Tensor], + s_start: torch.IntTensor, + s_end: torch.IntTensor, +): + for i in range(len(w)): + if s_end[i] - s_start[i] <= 0: + continue + + xi = x[s_start[i] : s_end[i]] + wi = w[i] + bi = b[i] + y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) From 18a8364d94ad92ae64063955abf251d23551e692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 25 Jun 2024 21:09:00 +0200 Subject: [PATCH 076/340] Support AWQ quantization with bias (#2117) When the AWQ quantizer was used with a layer that uses a bias, the bias tensor was not correctly passed/used. Instead, the value `true`/`1.0` was added to the linear transformation. Correctly pass through the bias when it is not `None`. Fixes #2106. --- .../layers/awq/quantize/qmodule.py | 10 +++++----- server/text_generation_server/layers/linear.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py index ca8caf50..c859db1b 100644 --- a/server/text_generation_server/layers/awq/quantize/qmodule.py +++ b/server/text_generation_server/layers/awq/quantize/qmodule.py @@ -1,6 +1,7 @@ # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py import math +from typing import Optional import torch import torch.nn as nn import awq_inference_engine # with CUDA kernels @@ -17,7 +18,9 @@ import awq_inference_engine # with CUDA kernels class WQLinear(nn.Module): - def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): + def __init__( + self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor] + ): super().__init__() if w_bit not in [4]: @@ -35,10 +38,7 @@ class WQLinear(nn.Module): self.qweight = qweight self.qzeros = qzeros self.scales = scales - if bias: - self.bias = bias - else: - self.bias = None + self.bias = bias @torch.no_grad() def forward(self, x): diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index d40b192f..207383a5 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -217,7 +217,7 @@ def get_linear(weight, bias, quantize): qweight=weight.qweight, qzeros=weight.qzeros, scales=weight.scales, - bias=bias is not None, + bias=bias, ) except ImportError: raise NotImplementedError( From 4700ea413f11a2bdd13930eea61853995cf2e4ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 25 Jun 2024 21:09:42 +0200 Subject: [PATCH 077/340] Add support for Marlin 2:4 sparsity (#2102) This change adds support for 2:4 sparsity when using Marlin quantization. The 2:4 kernel is used when: * The quantizer is `marlin`; * the quantizer checkpoint format is `marlin_24`. Fixes #2098. --- server/marlin/marlin_kernels/__init__.pyi | 17 + server/marlin/marlin_kernels/ext.cpp | 1 + server/marlin/marlin_kernels/ext.hh | 7 + .../marlin_kernels/sparse/common/base.h | 51 + .../marlin/marlin_kernels/sparse/common/mem.h | 136 ++ .../marlin/marlin_kernels/sparse/common/mma.h | 191 +++ .../sparse/marlin_24_cuda_kernel.cu | 1125 +++++++++++++++++ server/marlin/setup.py | 1 + .../text_generation_server/layers/linear.py | 9 +- .../text_generation_server/layers/marlin.py | 129 +- .../text_generation_server/utils/weights.py | 80 +- 11 files changed, 1731 insertions(+), 16 deletions(-) create mode 100644 server/marlin/marlin_kernels/sparse/common/base.h create mode 100644 server/marlin/marlin_kernels/sparse/common/mem.h create mode 100644 server/marlin/marlin_kernels/sparse/common/mma.h create mode 100644 server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi index 73597f0c..663984d0 100644 --- a/server/marlin/marlin_kernels/__init__.pyi +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -19,6 +19,23 @@ def gptq_marlin_gemm( """ ... +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + """ + Matrix multiplication using Marlin kernels. This is an extension of + `marlin_gemm` that supports 2:4 sparsity. + """ + ... + def gptq_marlin_repack( b_q_weight: torch.Tensor, perm: torch.Tensor, diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 5855714d..37eccef6 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -5,6 +5,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Marlin gemm with GPTQ compatibility"); + m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm"); m.def("gptq_marlin_repack", &gptq_marlin_repack, "Repack GPTQ parameters for Marlin"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index 9ea01a3f..d1caaab7 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -12,6 +12,13 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full); +torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, int64_t size_k, int64_t size_n, int64_t num_bits); diff --git a/server/marlin/marlin_kernels/sparse/common/base.h b/server/marlin/marlin_kernels/sparse/common/base.h new file mode 100644 index 00000000..16018d33 --- /dev/null +++ b/server/marlin/marlin_kernels/sparse/common/base.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace marlin_24 { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +template +struct ShapeBase { + static constexpr int M = M_, N = N_, K = K_; +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragM = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mem.h b/server/marlin/marlin_kernels/sparse/common/mem.h new file mode 100644 index 00000000..83e3578d --- /dev/null +++ b/server/marlin/marlin_kernels/sparse/common/mem.h @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred_zfill(void* smem_ptr, + const void* glob_ptr, + bool pred = true, + const bool zfill = false) { + const int BYTES = 16; + int src_in_bytes = (zfill ? 0 : BYTES); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); +} + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_m); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} +} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mma.h b/server/marlin/marlin_kernels/sparse/common/mma.h new file mode 100644 index 00000000..b26505f7 --- /dev/null +++ b/server/marlin/marlin_kernels/sparse/common/mma.h @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" +#include + +namespace marlin_24 { + +// On CUDA earlier than 12.5, the ordered_metadata version of this instruction +// is not supported. On later versions of CUDA the version without ordered +// metadata results in the following warning: +// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction +// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially +// | reduced performance on some future architectures +#if defined CUDA_VERSION && CUDA_VERSION >= 12050 + #define MMA_SP_INST \ + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " +#else + #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " +#endif + +// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, + const FragA& frag_b, FragC& frag_c, FragM& frag_m, + const int psel) { + const uint32_t* a0 = reinterpret_cast(&a_frag0); + const uint32_t* a1 = reinterpret_cast(&a_frag1); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* e = reinterpret_cast(&frag_m); + + float* c = reinterpret_cast(&frag_c); + if (psel == 0) { + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } else { + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, + float c3) { + uint2 r; + asm("{\n\t" + ".reg .f16 a, b, c, d; \n\t" + "cvt.rn.f16.f32 a, %2; \n\t" + "cvt.rn.f16.f32 b, %3; \n\t" + "cvt.rn.f16.f32 c, %4; \n\t" + "cvt.rn.f16.f32 d, %5; \n\t" + "mov.b32 %0, {a, b}; \n\t" + "mov.b32 %1, {c, d}; \n\t" + "}" + : "=r"(r.x), "=r"(r.y) + : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); + return r; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, + FragS& s0, float* c4, float* c5, float* c6, + float* c7, FragS& s1) { + *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); + *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); + *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); + *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); + + *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); + *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); + *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); + *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); +} + +} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu b/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu new file mode 100644 index 00000000..b5effc30 --- /dev/null +++ b/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu @@ -0,0 +1,1125 @@ +/* + * Notice: This file was modified by Neuralmagic inc to include 8-bit support + * + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include "common/base.h" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +#else + + #include "common/mem.h" + #include "common/mma.h" + +#endif + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin_24 { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int THREADS = 256; +static constexpr int STAGES = 4; + +static constexpr int min_thread_n = 128; + +static constexpr int tile_size = 16; +static constexpr int max_par = 64; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_24( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_24( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + // number of thread_k_blocks in k-dim + int k_tiles = prob_k / 32 / thread_k_blocks; + // number of thread_n_blocks in n-dim + int n_tiles = prob_n / 16 / thread_n_blocks; + // iters needed to cover all slices + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + // number of threadblock tiles in the current slice + int slice_iters; + // total number of active threadblocks in the current slice + int slice_count = 0; + // index of threadblock in current slice; numbered bottom to top + int slice_idx; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 32 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads //RLC: 2 * #warps k-dim + constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + constexpr int pack_factor = 32 / num_bits; + + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4* meta_ptr[m_sh_iters]; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + int4* sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2][b_thread_vecs]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + int4* sh_meta_stage = sh_m + m_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + + // Load meta with ldsm4 + int4* sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j2 = 0; j2 < 2; j2++) { + #pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j2 = 0; j2 < 2; j2++) { + #pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, + float c4, float c5, float c6, float c7, FragS& s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2* tmp = (half2*)&res; + // for per-column quantization we finally apply the scale here + if constexpr (group_blocks == -1 && num_bits == 4) { + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4*)sh)[idx] = *((int4*)&res[0]); + }; + + // RLC: only warp 0 and 1 baseline example + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], + frag_s[0][2]); + write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], + frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], + frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], + frag_c[i][3][0][3], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], + frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], + frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], + frag_c[i][3][1][2], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], + frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], + frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], + frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); + + c_sh_wr += 8 * c_sh_stride_2; + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + matmul(pipe); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + + pipe++; + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + } + } + thread_block_reduce(); + + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + } + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], + &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], + &frag_c[i][0][0][2], &frag_c[i][1][0][2], + &frag_c[i][2][0][2], &frag_c[i][3][0][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], + &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], + &frag_c[i][0][0][3], &frag_c[i][1][0][3], + &frag_c[i][2][0][3], &frag_c[i][3][0][3], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], + &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], + &frag_c[i][0][1][2], &frag_c[i][1][1][2], + &frag_c[i][2][1][2], &frag_c[i][3][1][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], + &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], + &frag_c[i][0][1][3], &frag_c[i][1][1][3], + &frag_c[i][2][1][3], &frag_c[i][3][1][3], + frag_s[0][2]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#endif + +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ + } + +void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, + void* s, int prob_m, int prob_n, int prob_k, + void* workspace, int num_bits, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16) { + int tot_n = prob_n; + int tot_n_blocks = ceildiv(tot_n, 16); + int pad = 16 * tot_n_blocks - tot_n; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + TORCH_CHECK(sms > 0); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (thread_k == -1 || thread_m == -1) { + if (prob_n <= 16) { + // For small batchizes, better partitioningif is slightly more important + // than better compute utilization + thread_k = 128; + thread_m = 128; + } else if (prob_n <= 256) { + thread_k = 64; + thread_m = 256; + } else { + thread_k = 32; + thread_m = 512; + } + } + + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_m_blocks = thread_m / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, + " is not divisible by thread_m = ", thread_m); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, + " is not divisible by group_blocks = ", group_blocks); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + const int4* meta_ptr = (const int4*)meta; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + constexpr int max_m_blocks = 4; + + int* locks = (int*)workspace; + for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { + int thread_n_blocks = tot_n_blocks - i; + prob_n = tot_n - 16 * i; + int par = 1; + if (thread_n_blocks > max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); + if (par > max_par) par = max_par; + prob_n = (max_m_blocks * 16) * par; + i += max_m_blocks * (par - 1); + thread_n_blocks = max_m_blocks; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + + // the false is start of the CALL_IF macros + if (false) { + } // BMxBNxBK, group + // 4-bit + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 16, 2, 2, 4) + CALL_IF_2_4(4, 16, 3, 2, -1) + CALL_IF_2_4(4, 16, 3, 2, 4) + CALL_IF_2_4(4, 16, 4, 2, -1) + CALL_IF_2_4(4, 16, 4, 2, 4) + + CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 32, 2, 1, 4) + CALL_IF_2_4(4, 32, 3, 1, -1) + CALL_IF_2_4(4, 32, 3, 1, 4) + CALL_IF_2_4(4, 32, 4, 1, -1) + CALL_IF_2_4(4, 32, 4, 1, 4) + + // 8-bit + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 16, 2, 2, 4) + CALL_IF_2_4(8, 16, 3, 2, -1) + CALL_IF_2_4(8, 16, 3, 2, 4) + CALL_IF_2_4(8, 16, 4, 2, -1) + CALL_IF_2_4(8, 16, 4, 2, 4) + + CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 32, 2, 1, 4) + CALL_IF_2_4(8, 32, 3, 1, -1) + CALL_IF_2_4(8, 32, 3, 1, 4) + CALL_IF_2_4(8, 32, 4, 1, -1) + CALL_IF_2_4(8, 32, 4, 1, 4) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; + } +} + +} // namespace marlin_24 + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % marlin_24::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_24::tile_size)); + TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(marlin_24::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_24::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_24::tile_size)); + + int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); + + // Verify meta + TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, + "b_meta.size(0) = ", b_meta.size(0), + " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); + TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), + " is not size_n * 2 = ", size_n * 2); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify b_meta device and strides + TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); + TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + int thread_k = -1; + int thread_m = -1; + int sms = -1; + int max_par = marlin_24::max_par; + + int groupsize = -1; + if (b_scales.size(0) > 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + groupsize = size_k / b_scales.size(0); + groupsize /= 2; // Because of 24 + } + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 64, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_24::min_thread_n)); + int min_workspace_size = + (size_n / marlin_24::min_thread_n) * marlin_24::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + marlin_24::marlin_cuda_2_4( + a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), + num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_m, sms, max_par); + + return c; +} diff --git a/server/marlin/setup.py b/server/marlin/setup.py index 844e1139..aed84e9e 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -12,6 +12,7 @@ setup( "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", "marlin_kernels/marlin_cuda_kernel.cu", + "marlin_kernels/sparse/marlin_24_cuda_kernel.cu", "marlin_kernels/ext.cpp", ], extra_compile_args=extra_compile_args, diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 207383a5..dd48465f 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,7 +1,6 @@ from typing import Optional import torch from torch.nn import functional as F -from text_generation_server.layers.marlin import GPTQMarlinLinear from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "rocm": @@ -225,6 +224,9 @@ def get_linear(weight, bias, quantize): ) elif quantize == "marlin": from text_generation_server.layers.marlin import ( + GPTQMarlin24Linear, + GPTQMarlin24Weight, + GPTQMarlinLinear, GPTQMarlinWeight, MarlinLinear, MarlinWeight, @@ -235,6 +237,11 @@ def get_linear(weight, bias, quantize): weight=weight, bias=bias, ) + elif isinstance(weight, GPTQMarlin24Weight): + linear = GPTQMarlin24Linear( + weight=weight, + bias=bias, + ) elif isinstance(weight, MarlinWeight): linear = MarlinLinear(weight=weight, bias=bias) else: diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 4d4c635e..2207b2e4 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,9 +1,8 @@ from dataclasses import dataclass -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import torch import torch.nn as nn - from text_generation_server.utils.import_utils import SYSTEM try: @@ -177,12 +176,12 @@ class GPTQMarlinLinear(nn.Module): self.bits = weight.bits self.is_full_k = weight.is_full_k - self.register_buffer("qweight", weight.qweight) - self.register_buffer("scales", weight.scales) - self.register_buffer("g_idx", weight.g_idx) - self.register_buffer("perm", weight.perm) + self.qweight = weight.qweight + self.scales = weight.scales + self.g_idx = weight.g_idx + self.perm = weight.perm if bias is not None: - self.register_buffer("bias", bias) + self.bias = bias else: self.bias = None @@ -215,6 +214,116 @@ class GPTQMarlinLinear(nn.Module): return C +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + + +@dataclass +class GPTQMarlin24Weight: + """ + GPTQ-Marlin 2:4 weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. + B_meta (torch.Tensor): metadata for 2:4 sparsity. + s (torch.Tensor): float16 scales. + bits: quantized weight size. + """ + + B: torch.Tensor + B_meta: torch.Tensor + s: torch.Tensor + bits: int + + def __post_init__(self): + assert self.B.dtype == torch.int32 + assert self.B_meta.dtype == torch.int16 + assert self.s.dtype == torch.float16 + + +class GPTQMarlin24Linear(nn.Module): + def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + if weight.bits not in GPTQ_MARLIN_BITS: + supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) + raise RuntimeError( + f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" + ) + + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 + out_features = weight.s.shape[1] + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + + if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + supported_sizes = ", ".join( + str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ) + raise RuntimeError( + f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" + ) + + self.bits = weight.bits + weights_per_int32 = 32 // self.bits + + assert ( + out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 + ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" + assert ( + out_features % weights_per_int32 == 0 + ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" + + assert ( + in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 + ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" + if groupsize != -1 and in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisable by group size ({groupsize})" + ) + + self.B = weight.B + self.B_meta = weight.B_meta + self.s = weight.s + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, + dtype=torch.int, + device=weight.B.device, + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + C = marlin_kernels.gptq_marlin_24_gemm( + A.view(-1, A.shape[-1]), + self.B, + self.B_meta, + self.s, + self.workspace, + self.bits, + A.shape[0], + self.s.shape[1], + A.shape[1], + ) + + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + @dataclass class MarlinWeight: """ @@ -255,10 +364,10 @@ class MarlinLinear(nn.Module): 128, }, f"Group size must be -1 or 128, was {groupsize}" - self.register_buffer("B", weight.B) - self.register_buffer("s", weight.s) + self.B = weight.B + self.s = weight.s if bias is not None: - self.register_buffer("bias", bias) + self.bias = bias else: self.bias = None diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index e6142525..348d215c 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -13,6 +13,7 @@ from text_generation_server.utils.log import log_once @dataclass class _GPTQParams: bits: int + checkpoint_format: Optional[str] groupsize: int desc_act: bool quant_method: str @@ -263,12 +264,29 @@ class Weights: ) elif quantize == "marlin": from text_generation_server.layers.marlin import ( + GPTQMarlin24Weight, MarlinWeight, repack_gptq_for_marlin, ) quant_method = getattr(self, "quant_method", "marlin") - if quant_method == "gptq": + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + B = self.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = self.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = self.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + gptq_params = self._get_gptq_params() + weight = GPTQMarlin24Weight( + B=B, B_meta=B_meta, s=s, bits=gptq_params.bits + ) + elif quant_method == "gptq": gptq_params = self._get_gptq_params() try: qweight = self.get_packed_sharded( @@ -293,7 +311,6 @@ class Weights: sym=gptq_params.sym, sharded_infeatures=False, ) - else: B = self.get_packed_sharded( f"{prefix}.B", dim=1, block_sizes=block_sizes @@ -406,12 +423,36 @@ class Weights: elif quantize == "marlin": from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.marlin import ( + GPTQMarlin24Weight, MarlinWeight, repack_gptq_for_marlin, ) quant_method = getattr(self, "quant_method", "marlin") - if quant_method == "gptq": + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = torch.cat( + [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + gptq_params = self._get_gptq_params() + weight = GPTQMarlin24Weight( + B=B, B_meta=B_meta, s=s, bits=gptq_params.bits + ) + elif quant_method == "gptq": gptq_params = self._get_gptq_params() try: qweight = torch.cat( @@ -629,12 +670,35 @@ class Weights: elif quantize == "marlin": from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.marlin import ( + GPTQMarlin24Weight, MarlinWeight, repack_gptq_for_marlin, ) quant_method = getattr(self, "quant_method", "marlin") - if quant_method == "gptq": + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = self.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = self.get_tensor(f"{prefix}.s") + else: + s = self.get_sharded(f"{prefix}.s", dim=0) + + gptq_params = self._get_gptq_params() + weight = GPTQMarlin24Weight( + B=B, B_meta=B_meta, s=s, bits=gptq_params.bits + ) + elif quant_method == "gptq": log_once(logger.info, "Converting GPTQ model to Marlin packing format.") gptq_params = self._get_gptq_params() @@ -668,7 +732,7 @@ class Weights: B = self.get_sharded(f"{prefix}.B", dim=0) except RuntimeError: raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + "Cannot load `marlin` weight, make sure the model is already quantized." ) num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] @@ -688,6 +752,7 @@ class Weights: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() + checkpoint_format = getattr(self, "gptq_checkpoint_format", None) desc_act = False sym = True quant_method = "gptq" @@ -695,6 +760,7 @@ class Weights: try: bits = self.gptq_bits groupsize = self.gptq_groupsize + checkpoint_format = getattr(self, "gptq_checkpoint_format", None) desc_act = getattr(self, "gptq_desc_act", False) quant_method = getattr(self, "quant_method", "gptq") sym = getattr(self, "sym", True) @@ -703,6 +769,7 @@ class Weights: return _GPTQParams( bits=bits, + checkpoint_format=checkpoint_format, desc_act=desc_act, groupsize=groupsize, quant_method=quant_method, @@ -724,6 +791,9 @@ class Weights: self.gptq_groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models self.quant_method = data["quantization_config"]["quant_method"] + self.gptq_checkpoint_format = data["quantization_config"].get( + "checkpoint_format" + ) self.gptq_sym = data["quantization_config"]["sym"] self.gptq_desc_act = data["quantization_config"]["desc_act"] except Exception: From 399919d71504b5160396ac893f43a71e87d0c6e2 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 25 Jun 2024 19:30:10 -0400 Subject: [PATCH 078/340] fix: simplify kserve endpoint and fix imports (#2119) --- router/src/kserve.rs | 66 +++++++++++++++++++++----------------------- router/src/server.rs | 2 +- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/router/src/kserve.rs b/router/src/kserve.rs index b64efd38..c53fa481 100644 --- a/router/src/kserve.rs +++ b/router/src/kserve.rs @@ -1,15 +1,15 @@ +use crate::infer::Infer; use crate::{ default_parameters, server::{generate_internal, ComputeType}, - Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema, + Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema, }; use axum::extract::{Extension, Path}; -use axum::response::{IntoResponse, Response}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::IntoResponse; use axum::Json; use futures::stream::FuturesUnordered; use futures::TryStreamExt; -use reqwest::header::HeaderMap; -use reqwest::StatusCode; #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct OutputChunk { @@ -64,8 +64,6 @@ pub struct MetadataServerResponse { pub extensions: Vec, } -// Routes - #[utoipa::path( post, tag = "Text Generation Inference", @@ -76,13 +74,13 @@ pub struct MetadataServerResponse { example = json!({"error": "No response"})) ) )] -pub async fn kserve_health_live() -> Result)> { +pub async fn kserve_health_live() -> Json { let data = LiveResponse { live: true }; - Ok((HeaderMap::new(), Json(data)).into_response()) + Json(data) } #[utoipa::path( - post, + get, tag = "Text Generation Inference", path = "/v2/health/ready", responses( @@ -91,9 +89,9 @@ pub async fn kserve_health_live() -> Result Result)> { +pub async fn kserve_health_ready() -> Json { let data = ReadyResponse { live: true }; - Ok((HeaderMap::new(), Json(data)).into_response()) + Json(data) } #[utoipa::path( @@ -106,7 +104,7 @@ pub async fn kserve_health_ready() -> Result Result)> { +pub async fn kerve_server_metadata() -> Json { let data = MetadataServerResponse { name: "text-generation-inference".to_string(), version: env!("CARGO_PKG_VERSION").to_string(), @@ -116,7 +114,7 @@ pub async fn kerve_server_metadata() -> Result Result, -) -> Result)> { +) -> Json { let data = MetadataServerResponse { name: model_name, version: model_version, extensions: vec!["infer".to_string(), "ready".to_string()], }; - Ok((HeaderMap::new(), Json(data)).into_response()) + Json(data) +} + +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/v2/models/{model_name}/versions/{model_version}/ready", + responses( + (status = 200, description = "Model version is ready", body = ReadyResponse), + (status = 404, description = "Model or version not found", body = ErrorResponse, + example = json!({"error": "No response"})) + ) +)] +pub async fn kserve_model_metadata_ready( + Path((_model_name, _model_version)): Path<(String, String)>, +) -> Json { + let data = ReadyResponse { live: true }; + Json(data) } #[utoipa::path( @@ -155,7 +170,7 @@ pub async fn kserve_model_infer( infer: Extension, Extension(compute_type): Extension, Json(payload): Json, -) -> Result)> { +) -> Result)> { let id = payload.id.clone(); let str_inputs = payload .inputs @@ -226,22 +241,5 @@ pub async fn kserve_model_infer( outputs: output_chunks, }; - Ok((HeaderMap::new(), Json(inference_output)).into_response()) -} - -#[utoipa::path( - get, - tag = "Text Generation Inference", - path = "/v2/models/{model_name}/versions/{model_version}/ready", - responses( - (status = 200, description = "Model version is ready", body = ReadyResponse), - (status = 404, description = "Model or version not found", body = ErrorResponse, - example = json!({"error": "No response"})) - ) -)] -pub async fn kserve_model_metadata_ready( - Path((_model_name, _model_version)): Path<(String, String)>, -) -> Result)> { - let data = ReadyResponse { live: true }; - Ok((HeaderMap::new(), Json(data)).into_response()) + Ok((HeaderMap::new(), Json(inference_output))) } diff --git a/router/src/server.rs b/router/src/server.rs index 6e6b93b6..7f15bfdd 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1766,12 +1766,12 @@ pub async fn run( #[derive(OpenApi)] #[openapi( paths( - kserve_model_infer, kserve_health_live, kserve_health_ready, kerve_server_metadata, kserve_model_metadata, kserve_model_metadata_ready, + kserve_model_infer, ), components(schemas( InferenceOutput, From 7045598b20bb57d5b580abaf26131112bb7208fd Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jun 2024 08:08:43 +0200 Subject: [PATCH 079/340] Fixing prom leak by upgrading. (#2129) --- Cargo.lock | 424 ++++++++++++++++++++++++++++++++++------------ router/Cargo.toml | 6 +- 2 files changed, 318 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4d537290..c285e07b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,9 +78,9 @@ dependencies = [ [[package]] name = "anstyle-query" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5" +checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" dependencies = [ "windows-sys 0.52.0", ] @@ -121,7 +121,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -160,7 +160,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -171,9 +171,15 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.3.0" @@ -226,6 +232,33 @@ dependencies = [ "slotmap", ] +[[package]] +name = "aws-lc-rs" +version = "1.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d844e282b4b56750b2d4e893b2205581ded8709fddd2b6aa5418c150ca877" +dependencies = [ + "aws-lc-sys", + "mirai-annotations", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a2c29203f6bf296d01141cc8bb9dbd5ecd4c27843f2ee0767bcd5985a927da" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "libc", + "paste", +] + [[package]] name = "axum" version = "0.6.20" @@ -239,7 +272,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.28", + "hyper 0.14.29", "itoa", "matchit", "memchr", @@ -381,6 +414,29 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.6.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.68", + "which", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -410,15 +466,15 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.3.0" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c12d1856e42f0d817a835fe55853957c85c8c8a470114029143d3f12671446e" +checksum = "415f8399438eb5e4b2f73ed3152a3448b98149dda642a957ee704e1daa5cf1d8" [[package]] name = "block-buffer" @@ -449,9 +505,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.16.0" +version = "1.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" +checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" [[package]] name = "byteorder" @@ -681,7 +737,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "crossterm_winapi", "libc", "mio", @@ -747,7 +803,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -758,7 +814,7 @@ checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" dependencies = [ "darling_core", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -788,7 +844,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -798,7 +854,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -854,9 +910,9 @@ dependencies = [ [[package]] name = "either" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encode_unicode" @@ -1016,6 +1072,12 @@ dependencies = [ "num", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures" version = "0.3.30" @@ -1072,7 +1134,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1182,6 +1244,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.1.0", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.4.1" @@ -1198,20 +1279,14 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] [[package]] name = "heck" @@ -1252,6 +1327,15 @@ dependencies = [ "ureq", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "hostname" version = "0.3.1" @@ -1308,12 +1392,12 @@ dependencies = [ [[package]] name = "http-body-util" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", - "futures-core", + "futures-util", "http 1.1.0", "http-body 1.0.0", "pin-project-lite", @@ -1364,6 +1448,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2 0.4.5", "http 1.1.0", "http-body 1.0.0", "httparse", @@ -1372,6 +1457,26 @@ dependencies = [ "pin-project-lite", "smallvec", "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.3.1", + "hyper-util", + "log", + "rustls 0.23.10", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", ] [[package]] @@ -1380,7 +1485,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.28", + "hyper 0.14.29", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1514,7 +1619,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1757,9 +1862,15 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "lebe" @@ -1784,6 +1895,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "libloading" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" +dependencies = [ + "cfg-if", + "windows-targets 0.52.5", +] + [[package]] name = "libm" version = "0.2.8" @@ -1796,7 +1917,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] @@ -1837,15 +1958,6 @@ dependencies = [ "imgref", ] -[[package]] -name = "mach2" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" -dependencies = [ - "libc", -] - [[package]] name = "macro_rules_attribute" version = "0.2.0" @@ -1911,16 +2023,29 @@ dependencies = [ ] [[package]] -name = "metrics-exporter-prometheus" -version = "0.12.2" +name = "metrics" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d4fa7ce7c4862db464a37b0b31d89bca874562f034bd7993895572783d02950" +checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" dependencies = [ - "base64 0.21.7", - "hyper 0.14.28", - "indexmap 1.9.3", + "ahash", + "portable-atomic", +] + +[[package]] +name = "metrics-exporter-prometheus" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf0af7a0d7ced10c0151f870e5e3f3f8bc9ffc5992d32873566ca1f9169ae776" +dependencies = [ + "base64 0.22.1", + "http-body-util", + "hyper 1.3.1", + "hyper-rustls", + "hyper-util", + "indexmap 2.2.6", "ipnet", - "metrics", + "metrics 0.23.0", "metrics-util", "quanta", "thiserror", @@ -1936,19 +2061,19 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] name = "metrics-util" -version = "0.15.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4de2ed6e491ed114b40b732e4d1659a9d53992ebd87490c44a6ffe23739d973e" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" dependencies = [ "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.13.1", - "metrics", + "hashbrown 0.14.5", + "metrics 0.23.0", "num_cpus", "quanta", "sketches-ddsketch", @@ -1997,9 +2122,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ "adler", "simd-adler32", @@ -2017,6 +2142,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mirai-annotations" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" + [[package]] name = "monostate" version = "0.1.13" @@ -2035,7 +2166,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2101,12 +2232,12 @@ dependencies = [ "bytes", "futures", "hostname", - "hyper 0.14.28", + "hyper 0.14.29", "muxado", "once_cell", "parking_lot", "regex", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "thiserror", @@ -2123,7 +2254,7 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", "cfg_aliases", "libc", @@ -2223,7 +2354,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2335,7 +2466,7 @@ version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", "foreign-types", "libc", @@ -2352,7 +2483,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2601,7 +2732,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2660,7 +2791,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2689,9 +2820,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -2712,7 +2843,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2752,7 +2883,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.66", + "syn 2.0.68", "tempfile", ] @@ -2802,13 +2933,12 @@ dependencies = [ [[package]] name = "quanta" -version = "0.11.1" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" dependencies = [ "crossbeam-utils", "libc", - "mach2", "once_cell", "raw-cpuid", "wasi", @@ -2867,7 +2997,7 @@ version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cassowary", "crossterm", "indoc", @@ -2915,9 +3045,9 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.5" +version = "0.11.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc13288f5ab39e6d7c9d501759712e6969fcc9734220846fc9ed26cae2cc4234" +checksum = "67376f469e7e7840d0040bbf4b9b3334005bb167f814621326e4c7ab8cd6e944" dependencies = [ "avif-serialize", "imgref", @@ -2930,11 +3060,11 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "10.7.0" +version = "11.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", ] [[package]] @@ -3043,10 +3173,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.28", + "hyper 0.14.29", "hyper-tls", "ipnet", "js-sys", @@ -3056,7 +3186,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", @@ -3152,6 +3282,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -3167,7 +3303,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -3200,6 +3336,34 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls" +version = "0.23.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.2", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -3209,6 +3373,16 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pemfile" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +dependencies = [ + "base64 0.22.1", + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.7.0" @@ -3221,6 +3395,7 @@ version = "0.102.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" dependencies = [ + "aws-lc-rs", "ring 0.17.8", "rustls-pki-types", "untrusted 0.9.0", @@ -3278,7 +3453,7 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -3321,14 +3496,14 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" dependencies = [ "itoa", "ryu", @@ -3386,6 +3561,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook" version = "0.3.17" @@ -3529,14 +3710,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -3551,9 +3732,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" dependencies = [ "proc-macro2", "quote", @@ -3739,7 +3920,7 @@ dependencies = [ "image", "init-tracing-opentelemetry", "jsonschema", - "metrics", + "metrics 0.21.1", "metrics-exporter-prometheus", "minijinja", "minijinja-contrib", @@ -3760,6 +3941,8 @@ dependencies = [ "tokio-stream", "tower-http", "tracing", + "tracing-core", + "tracing-log 0.2.0", "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", @@ -3784,7 +3967,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3921,7 +4104,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3945,6 +4128,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.10", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.15" @@ -4016,10 +4210,10 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.28", + "hyper 0.14.29", "hyper-timeout", "percent-encoding", "pin-project", @@ -4069,7 +4263,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -4098,7 +4292,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "bytes", "http 1.1.0", "http-body 1.0.0", @@ -4140,7 +4334,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -4395,7 +4589,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -4416,9 +4610,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.8.0" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" [[package]] name = "v_frame" @@ -4517,7 +4711,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", "wasm-bindgen-shared", ] @@ -4551,7 +4745,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4607,6 +4801,18 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "winapi" version = "0.3.9" @@ -4934,7 +5140,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] diff --git a/router/Cargo.toml b/router/Cargo.toml index 5bf4c00c..853f46b1 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -24,7 +24,7 @@ futures = "0.3.28" hf-hub = { workspace = true } jsonschema = { version = "0.17.1", features = ["draft202012"] } metrics = "0.21.1" -metrics-exporter-prometheus = { version = "0.12.1", features = [] } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.13.0" @@ -37,9 +37,9 @@ tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } -tracing = "0.1.37" +tracing = "0.1.40" tracing-opentelemetry = "0.21.0" -tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } From 11fced79bdffefff59fd019423cbb72cf63541aa Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jun 2024 12:34:43 +0200 Subject: [PATCH 080/340] Bumping to 2.1 (#2131) --- docs/source/installation_amd.md | 2 +- docs/source/installation_nvidia.md | 2 +- docs/source/quicktour.md | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index bf7f9c75..fe925e2a 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \ + ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \ --model-id $model ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 9077f7fd..11c41763 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.0.4 \ + ghcr.io/huggingface/text-generation-inference:2.1.0 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index b84de85d..09e56df4 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.0.4 \ + ghcr.io/huggingface/text-generation-inference:2.1.0 \ --model-id $model ``` @@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. ```bash -docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help +docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help ``` From d731866245631ad05cac34d39167219572f3db42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 27 Jun 2024 15:54:35 +0200 Subject: [PATCH 081/340] Idefics2: sync added image tokens with transformers (#2080) Before this change, the number of reserved image tokens was not the same as the number of images. Fixes #2029. While at it, also remove all the image token handling duplication in `prepare_input`. --- Cargo.lock | 1 + .../test_flash_idefics2_next_load.json | 11124 ++++++++-------- .../test_flash_idefics2_next_simple.json | 20 +- .../test_flash_idefics2_two_images.json | 38 +- router/Cargo.toml | 1 + router/src/config.rs | 12 +- router/src/lib.rs | 19 + router/src/main.rs | 12 +- router/src/server.rs | 8 +- router/src/validation.rs | 242 +- .../models/custom_modeling/llava_next.py | 9 +- .../models/pali_gemma.py | 4 +- .../models/vlm_causal_lm.py | 57 +- 13 files changed, 5887 insertions(+), 5660 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c285e07b..37c2553e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3919,6 +3919,7 @@ dependencies = [ "hf-hub", "image", "init-tracing-opentelemetry", + "itertools 0.10.5", "jsonschema", "metrics 0.21.1", "metrics-exporter-prometheus", diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json index 4bc90896..7f1875e0 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json @@ -37,7 +37,7 @@ }, { "id": 32001, - "logprob": -19.484375, + "logprob": -19.46875, "text": "" }, { @@ -57,7 +57,7 @@ }, { "id": 32001, - "logprob": -20.234375, + "logprob": -20.21875, "text": "" }, { @@ -65,11 +65,1785 @@ "logprob": -16.421875, "text": "" }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -23.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.0, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -3.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.2734375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -3.2988281, + "text": "" + }, + { + "id": 32001, + "logprob": -25.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.7421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -2.7207031, + "text": "" + }, + { + "id": 32001, + "logprob": -23.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, { "id": 32001, "logprob": -19.828125, "text": "" }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -3.0917969, + "text": "" + }, + { + "id": 32001, + "logprob": -25.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.3984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.5, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -2.0332031, + "text": "" + }, + { + "id": 12018, + "logprob": -12.078125, + "text": "Write" + }, + { + "id": 528, + "logprob": -10.09375, + "text": "me" + }, + { + "id": 264, + "logprob": -0.103393555, + "text": "a" + }, + { + "id": 2485, + "logprob": -4.5742188, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.23815918, + "text": "story" + }, + { + "id": 32002, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 259, + "logprob": -20.34375, + "text": " " + }, + { + "id": 13, + "logprob": -8.53125, + "text": "\n" + }, + { + "id": 7226, + "logprob": -10.4765625, + "text": "Ass" + }, + { + "id": 11143, + "logprob": -13.6015625, + "text": "istant" + }, + { + "id": 28747, + "logprob": -0.008514404, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.09289551, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.6743164, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.31396484, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.051727295, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.34448242, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.1194458, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.03237915, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.00018751621, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.07043457, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.00422287, + "special": false, + "text": "." + } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -5.2382812, + "text": "User" + }, + { + "id": 28747, + "logprob": -6.9492188, + "text": ":" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, { "id": 32001, "logprob": -23.25, @@ -102,7 +1876,7 @@ }, { "id": 32001, - "logprob": -21.015625, + "logprob": -21.03125, "text": "" }, { @@ -112,12 +1886,12 @@ }, { "id": 32001, - "logprob": -16.015625, + "logprob": -16.03125, "text": "" }, { "id": 32001, - "logprob": -19.0625, + "logprob": -19.046875, "text": "" }, { @@ -127,12 +1901,7 @@ }, { "id": 32001, - "logprob": -23.625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.40625, + "logprob": -23.609375, "text": "" }, { @@ -142,7 +1911,12 @@ }, { "id": 32001, - "logprob": -20.84375, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, "text": "" }, { @@ -152,12 +1926,12 @@ }, { "id": 32001, - "logprob": -19.984375, + "logprob": -19.96875, "text": "" }, { "id": 32001, - "logprob": -18.21875, + "logprob": -18.234375, "text": "" }, { @@ -167,7 +1941,7 @@ }, { "id": 32001, - "logprob": -21.203125, + "logprob": -21.1875, "text": "" }, { @@ -182,7 +1956,7 @@ }, { "id": 32001, - "logprob": -18.984375, + "logprob": -19.03125, "text": "" }, { @@ -197,17 +1971,17 @@ }, { "id": 32001, - "logprob": -18.0, + "logprob": -17.96875, "text": "" }, { "id": 32001, - "logprob": -18.828125, + "logprob": -18.859375, "text": "" }, { "id": 32001, - "logprob": -17.9375, + "logprob": -17.921875, "text": "" }, { @@ -217,12 +1991,12 @@ }, { "id": 32001, - "logprob": -18.640625, + "logprob": -18.65625, "text": "" }, { "id": 32001, - "logprob": -20.125, + "logprob": -20.140625, "text": "" }, { @@ -247,7 +2021,7 @@ }, { "id": 32001, - "logprob": -17.4375, + "logprob": -17.421875, "text": "" }, { @@ -257,7 +2031,7 @@ }, { "id": 32001, - "logprob": -23.015625, + "logprob": -23.0, "text": "" }, { @@ -282,7 +2056,7 @@ }, { "id": 32001, - "logprob": -18.40625, + "logprob": -18.421875, "text": "" }, { @@ -292,17 +2066,17 @@ }, { "id": 32001, - "logprob": -18.34375, + "logprob": -18.328125, "text": "" }, { "id": 32001, - "logprob": -17.140625, + "logprob": -17.125, "text": "" }, { "id": 32001, - "logprob": -18.671875, + "logprob": -18.65625, "text": "" }, { @@ -317,7 +2091,7 @@ }, { "id": 32001, - "logprob": -18.1875, + "logprob": -18.15625, "text": "" }, { @@ -337,12 +2111,7 @@ }, { "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.9375, + "logprob": -18.703125, "text": "" }, { @@ -351,883 +2120,38 @@ "text": "" }, { - "id": 32001, - "logprob": -19.125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.25, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.828125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4140625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.7265625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -14.2421875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -14.8515625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -18.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -20.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.5, - "text": "" - }, - { - "id": 32001, - "logprob": -20.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.9453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.703125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.6875, - "text": "" - }, - { - "id": 32001, - "logprob": -22.625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.75, - "text": "" - }, - { - "id": 32001, - "logprob": -21.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.828125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.75, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.375, - "text": "" + "id": 32000, + "logprob": -3.015625, + "text": "" }, { "id": 32001, - "logprob": -20.140625, + "logprob": -22.109375, "text": "" }, { "id": 32001, - "logprob": -21.140625, + "logprob": -18.96875, "text": "" }, { "id": 32001, - "logprob": -21.6875, + "logprob": -20.125, "text": "" }, { "id": 32001, - "logprob": -21.453125, + "logprob": -17.125, "text": "" }, { "id": 32001, - "logprob": -19.171875, + "logprob": -17.8125, "text": "" }, { "id": 32001, - "logprob": -17.78125, + "logprob": -19.3125, "text": "" }, { @@ -1237,157 +2161,27 @@ }, { "id": 32001, - "logprob": -17.078125, + "logprob": -16.3125, "text": "" }, { "id": 32001, - "logprob": -17.109375, + "logprob": -19.375, "text": "" }, { "id": 32001, - "logprob": -19.171875, + "logprob": -20.046875, "text": "" }, { "id": 32001, - "logprob": -20.453125, + "logprob": -20.828125, "text": "" }, { "id": 32001, - "logprob": -21.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.75, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8828125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.1171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0, - "text": "" - }, - { - "id": 32001, - "logprob": -20.75, + "logprob": -15.8046875, "text": "" }, { @@ -1397,7 +2191,12 @@ }, { "id": 32001, - "logprob": -19.46875, + "logprob": -19.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, "text": "" }, { @@ -1407,22 +2206,297 @@ }, { "id": 32001, - "logprob": -22.421875, + "logprob": -20.515625, "text": "" }, { "id": 32001, - "logprob": -20.9375, + "logprob": -19.171875, "text": "" }, { "id": 32001, - "logprob": -19.671875, + "logprob": -19.296875, "text": "" }, { "id": 32001, - "logprob": -20.890625, + "logprob": -16.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.2734375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -3.2988281, + "text": "" + }, + { + "id": 32001, + "logprob": -25.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.203125, "text": "" }, { @@ -1432,12 +2506,732 @@ }, { "id": 32001, - "logprob": -17.5, + "logprob": -15.75, "text": "" }, { "id": 32001, - "logprob": -17.90625, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -2.7207031, + "text": "" + }, + { + "id": 32001, + "logprob": -23.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -3.0917969, + "text": "" + }, + { + "id": 32001, + "logprob": -25.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, "text": "" }, { @@ -1452,37 +3246,37 @@ }, { "id": 32001, - "logprob": -16.40625, + "logprob": -16.375, "text": "" }, { "id": 32001, - "logprob": -18.453125, + "logprob": -18.4375, "text": "" }, { "id": 32001, - "logprob": -20.234375, + "logprob": -20.265625, "text": "" }, { "id": 32001, - "logprob": -22.28125, + "logprob": -22.296875, "text": "" }, { "id": 32001, - "logprob": -18.515625, + "logprob": -18.484375, "text": "" }, { "id": 32001, - "logprob": -15.4296875, + "logprob": -15.390625, "text": "" }, { "id": 32001, - "logprob": -19.765625, + "logprob": -19.75, "text": "" }, { @@ -1492,32 +3286,32 @@ }, { "id": 32001, - "logprob": -21.46875, + "logprob": -21.609375, "text": "" }, { "id": 32001, - "logprob": -18.875, + "logprob": -18.828125, "text": "" }, { "id": 32001, - "logprob": -20.859375, + "logprob": -20.828125, "text": "" }, { "id": 32001, - "logprob": -17.078125, + "logprob": -17.015625, "text": "" }, { "id": 32001, - "logprob": -16.4375, + "logprob": -16.40625, "text": "" }, { "id": 32001, - "logprob": -21.015625, + "logprob": -21.046875, "text": "" }, { @@ -1532,22 +3326,22 @@ }, { "id": 32001, - "logprob": -21.484375, + "logprob": -21.515625, "text": "" }, { "id": 32001, - "logprob": -20.015625, + "logprob": -20.0, "text": "" }, { "id": 32001, - "logprob": -18.84375, + "logprob": -18.78125, "text": "" }, { "id": 32001, - "logprob": -16.40625, + "logprob": -16.375, "text": "" }, { @@ -1557,57 +3351,57 @@ }, { "id": 32001, - "logprob": -16.65625, + "logprob": -16.703125, "text": "" }, { "id": 32001, - "logprob": -13.6328125, + "logprob": -13.625, "text": "" }, { "id": 32001, - "logprob": -15.4140625, + "logprob": -15.375, "text": "" }, { "id": 32001, - "logprob": -17.546875, + "logprob": -17.515625, "text": "" }, { "id": 32001, - "logprob": -21.859375, + "logprob": -21.921875, "text": "" }, { "id": 32001, - "logprob": -15.65625, + "logprob": -15.640625, "text": "" }, { "id": 32001, - "logprob": -16.484375, + "logprob": -16.46875, "text": "" }, { "id": 32001, - "logprob": -16.359375, + "logprob": -16.421875, "text": "" }, { "id": 32001, - "logprob": -19.9375, + "logprob": -19.890625, "text": "" }, { "id": 32001, - "logprob": -17.875, + "logprob": -17.890625, "text": "" }, { "id": 32001, - "logprob": -17.453125, + "logprob": -17.40625, "text": "" }, { @@ -1617,72 +3411,72 @@ }, { "id": 32001, - "logprob": -19.171875, + "logprob": -19.1875, "text": "" }, { "id": 32001, - "logprob": -15.9921875, + "logprob": -15.9609375, "text": "" }, { "id": 32000, - "logprob": -2.0429688, + "logprob": -2.0332031, "text": "" }, { "id": 12018, - "logprob": -12.03125, + "logprob": -12.078125, "text": "Write" }, { "id": 528, - "logprob": -10.25, + "logprob": -10.109375, "text": "me" }, { "id": 264, - "logprob": -0.10437012, + "logprob": -0.103515625, "text": "a" }, { "id": 2485, - "logprob": -4.5742188, + "logprob": -4.5664062, "text": "short" }, { "id": 2838, - "logprob": -0.2277832, + "logprob": -0.23864746, "text": "story" }, { "id": 32002, - "logprob": -10.84375, + "logprob": -10.9609375, "text": "" }, { "id": 259, - "logprob": -20.1875, + "logprob": -20.34375, "text": " " }, { "id": 13, - "logprob": -8.7578125, + "logprob": -8.5546875, "text": "\n" }, { "id": 7226, - "logprob": -10.421875, + "logprob": -10.484375, "text": "Ass" }, { "id": 11143, - "logprob": -13.640625, + "logprob": -13.6015625, "text": "istant" }, { "id": 28747, - "logprob": -0.005619049, + "logprob": -0.008308411, "text": ":" } ], @@ -1690,61 +3484,61 @@ "tokens": [ { "id": 330, - "logprob": -0.12939453, + "logprob": -0.09448242, "special": false, "text": " A" }, { "id": 13088, - "logprob": -0.6660156, + "logprob": -0.6743164, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.29638672, + "logprob": -0.31201172, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.05960083, + "logprob": -0.051635742, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.26953125, + "logprob": -0.34033203, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.1427002, + "logprob": -0.1194458, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.040649414, + "logprob": -0.032562256, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.0002708435, + "logprob": -0.00018763542, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.09429932, + "logprob": -0.07122803, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.006931305, + "logprob": -0.0041007996, "special": false, "text": "." } @@ -1766,12 +3560,12 @@ }, { "id": 1247, - "logprob": -5.234375, + "logprob": -5.2382812, "text": "User" }, { "id": 28747, - "logprob": -6.9648438, + "logprob": -6.9492188, "text": ":" }, { @@ -1781,12 +3575,12 @@ }, { "id": 32001, - "logprob": -18.96875, + "logprob": -18.984375, "text": "" }, { "id": 32001, - "logprob": -18.1875, + "logprob": -18.171875, "text": "" }, { @@ -1811,7 +3605,7 @@ }, { "id": 32001, - "logprob": -20.234375, + "logprob": -20.21875, "text": "" }, { @@ -1821,7 +3615,7 @@ }, { "id": 32001, - "logprob": -19.828125, + "logprob": -19.84375, "text": "" }, { @@ -1856,22 +3650,22 @@ }, { "id": 32001, - "logprob": -21.015625, + "logprob": -21.03125, "text": "" }, { "id": 32001, - "logprob": -20.4375, + "logprob": -20.421875, "text": "" }, { "id": 32001, - "logprob": -16.015625, + "logprob": -16.03125, "text": "" }, { "id": 32001, - "logprob": -19.0625, + "logprob": -19.046875, "text": "" }, { @@ -1886,7 +3680,7 @@ }, { "id": 32001, - "logprob": -20.40625, + "logprob": -20.421875, "text": "" }, { @@ -1896,7 +3690,7 @@ }, { "id": 32001, - "logprob": -20.84375, + "logprob": -20.875, "text": "" }, { @@ -1906,12 +3700,12 @@ }, { "id": 32001, - "logprob": -19.984375, + "logprob": -19.96875, "text": "" }, { "id": 32001, - "logprob": -18.21875, + "logprob": -18.234375, "text": "" }, { @@ -1921,7 +3715,7 @@ }, { "id": 32001, - "logprob": -21.203125, + "logprob": -21.1875, "text": "" }, { @@ -1936,7 +3730,7 @@ }, { "id": 32001, - "logprob": -18.984375, + "logprob": -19.03125, "text": "" }, { @@ -1951,17 +3745,17 @@ }, { "id": 32001, - "logprob": -18.0, + "logprob": -17.96875, "text": "" }, { "id": 32001, - "logprob": -18.828125, + "logprob": -18.859375, "text": "" }, { "id": 32001, - "logprob": -17.9375, + "logprob": -17.921875, "text": "" }, { @@ -1971,12 +3765,12 @@ }, { "id": 32001, - "logprob": -18.640625, + "logprob": -18.65625, "text": "" }, { "id": 32001, - "logprob": -20.125, + "logprob": -20.140625, "text": "" }, { @@ -1996,12 +3790,12 @@ }, { "id": 32001, - "logprob": -23.203125, + "logprob": -23.21875, "text": "" }, { "id": 32001, - "logprob": -17.4375, + "logprob": -17.421875, "text": "" }, { @@ -2011,7 +3805,7 @@ }, { "id": 32001, - "logprob": -23.015625, + "logprob": -23.0, "text": "" }, { @@ -2036,7 +3830,7 @@ }, { "id": 32001, - "logprob": -18.40625, + "logprob": -18.421875, "text": "" }, { @@ -2046,7 +3840,7 @@ }, { "id": 32001, - "logprob": -18.34375, + "logprob": -18.328125, "text": "" }, { @@ -2056,7 +3850,7 @@ }, { "id": 32001, - "logprob": -18.671875, + "logprob": -18.65625, "text": "" }, { @@ -2071,7 +3865,7 @@ }, { "id": 32001, - "logprob": -18.1875, + "logprob": -18.15625, "text": "" }, { @@ -2081,7 +3875,7 @@ }, { "id": 32001, - "logprob": -20.140625, + "logprob": -20.15625, "text": "" }, { @@ -2089,1643 +3883,29 @@ "logprob": -18.96875, "text": "" }, - { - "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.25, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.828125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4140625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.7265625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -14.2421875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -14.8515625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.203125, - "text": "" - }, { "id": 32001, "logprob": -18.703125, "text": "" }, - { - "id": 32001, - "logprob": -19.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -18.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -20.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.5, - "text": "" - }, - { - "id": 32001, - "logprob": -20.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.9453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.703125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.6875, - "text": "" - }, - { - "id": 32001, - "logprob": -22.625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.75, - "text": "" - }, - { - "id": 32001, - "logprob": -21.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.828125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.75, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.703125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.75, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.1171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0, - "text": "" - }, - { - "id": 32001, - "logprob": -20.75, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -22.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.5, - "text": "" - }, - { - "id": 32001, - "logprob": -17.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -22.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.75, - "text": "" - }, - { - "id": 32001, - "logprob": -16.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -22.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4296875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -14.6484375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -13.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.9921875, - "text": "" - }, - { - "id": 32000, - "logprob": -2.0429688, - "text": "" - }, - { - "id": 12018, - "logprob": -12.03125, - "text": "Write" - }, - { - "id": 528, - "logprob": -10.2578125, - "text": "me" - }, - { - "id": 264, - "logprob": -0.10418701, - "text": "a" - }, - { - "id": 2485, - "logprob": -4.5664062, - "text": "short" - }, - { - "id": 2838, - "logprob": -0.22741699, - "text": "story" - }, - { - "id": 32002, - "logprob": -10.8515625, - "text": "" - }, - { - "id": 259, - "logprob": -20.203125, - "text": " " - }, - { - "id": 13, - "logprob": -8.7421875, - "text": "\n" - }, - { - "id": 7226, - "logprob": -10.4140625, - "text": "Ass" - }, - { - "id": 11143, - "logprob": -13.6328125, - "text": "istant" - }, - { - "id": 28747, - "logprob": -0.005580902, - "text": ":" - } - ], - "seed": null, - "tokens": [ - { - "id": 330, - "logprob": -0.1295166, - "special": false, - "text": " A" - }, - { - "id": 13088, - "logprob": -0.6669922, - "special": false, - "text": " chicken" - }, - { - "id": 349, - "logprob": -0.29711914, - "special": false, - "text": " is" - }, - { - "id": 6398, - "logprob": -0.059936523, - "special": false, - "text": " sitting" - }, - { - "id": 356, - "logprob": -0.27124023, - "special": false, - "text": " on" - }, - { - "id": 264, - "logprob": -0.140625, - "special": false, - "text": " a" - }, - { - "id": 17972, - "logprob": -0.04058838, - "special": false, - "text": " pile" - }, - { - "id": 302, - "logprob": -0.00027012825, - "special": false, - "text": " of" - }, - { - "id": 2445, - "logprob": -0.09503174, - "special": false, - "text": " money" - }, - { - "id": 28723, - "logprob": -0.006942749, - "special": false, - "text": "." - } - ], - "top_tokens": null - }, - "generated_text": " A chicken is sitting on a pile of money." - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 1, - "logprob": null, - "text": "" - }, - { - "id": 1247, - "logprob": -5.2460938, - "text": "User" - }, - { - "id": 28747, - "logprob": -6.9570312, - "text": ":" - }, - { - "id": 32000, - "logprob": -16.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -23.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -22.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -23.625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.5, - "text": "" - }, - { - "id": 32001, - "logprob": -19.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -23.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.53125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.328125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, { "id": 32001, "logprob": -17.921875, "text": "" }, + { + "id": 32000, + "logprob": -3.015625, + "text": "" + }, { "id": 32001, - "logprob": -19.1875, + "logprob": -22.109375, "text": "" }, { "id": 32001, - "logprob": -18.640625, + "logprob": -18.96875, "text": "" }, { @@ -3733,76 +3913,6 @@ "logprob": -20.125, "text": "" }, - { - "id": 32001, - "logprob": -19.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -14.8828125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -23.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -23.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -23.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.75, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.34375, - "text": "" - }, { "id": 32001, "logprob": -17.125, @@ -3810,932 +3920,12 @@ }, { "id": 32001, - "logprob": -18.671875, + "logprob": -17.8125, "text": "" }, { "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.25, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8359375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4140625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.7265625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -14.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -14.8515625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.0, - "text": "" - }, - { - "id": 32001, - "logprob": -17.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -18.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -20.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.9453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.703125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.6875, - "text": "" - }, - { - "id": 32001, - "logprob": -22.625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.828125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.75, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.703125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.765625, + "logprob": -19.3125, "text": "" }, { @@ -4743,1385 +3933,11 @@ "logprob": -19.65625, "text": "" }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.75, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.1171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0, - "text": "" - }, - { - "id": 32001, - "logprob": -20.75, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -22.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.5, - "text": "" - }, - { - "id": 32001, - "logprob": -17.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -22.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.75, - "text": "" - }, - { - "id": 32001, - "logprob": -16.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -22.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.53125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4296875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -14.6484375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0, - "text": "" - }, - { - "id": 32001, - "logprob": -21.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.890625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -13.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4140625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.9921875, - "text": "" - }, - { - "id": 32000, - "logprob": -2.0429688, - "text": "" - }, - { - "id": 12018, - "logprob": -12.0390625, - "text": "Write" - }, - { - "id": 528, - "logprob": -10.25, - "text": "me" - }, - { - "id": 264, - "logprob": -0.10443115, - "text": "a" - }, - { - "id": 2485, - "logprob": -4.5742188, - "text": "short" - }, - { - "id": 2838, - "logprob": -0.22729492, - "text": "story" - }, - { - "id": 32002, - "logprob": -10.84375, - "text": "" - }, - { - "id": 259, - "logprob": -20.1875, - "text": " " - }, - { - "id": 13, - "logprob": -8.7578125, - "text": "\n" - }, - { - "id": 7226, - "logprob": -10.4140625, - "text": "Ass" - }, - { - "id": 11143, - "logprob": -13.6328125, - "text": "istant" - }, - { - "id": 28747, - "logprob": -0.0056533813, - "text": ":" - } - ], - "seed": null, - "tokens": [ - { - "id": 330, - "logprob": -0.12963867, - "special": false, - "text": " A" - }, - { - "id": 13088, - "logprob": -0.6660156, - "special": false, - "text": " chicken" - }, - { - "id": 349, - "logprob": -0.29516602, - "special": false, - "text": " is" - }, - { - "id": 6398, - "logprob": -0.060028076, - "special": false, - "text": " sitting" - }, - { - "id": 356, - "logprob": -0.27075195, - "special": false, - "text": " on" - }, - { - "id": 264, - "logprob": -0.1427002, - "special": false, - "text": " a" - }, - { - "id": 17972, - "logprob": -0.04067993, - "special": false, - "text": " pile" - }, - { - "id": 302, - "logprob": -0.000269413, - "special": false, - "text": " of" - }, - { - "id": 2445, - "logprob": -0.09387207, - "special": false, - "text": " money" - }, - { - "id": 28723, - "logprob": -0.0069236755, - "special": false, - "text": "." - } - ], - "top_tokens": null - }, - "generated_text": " A chicken is sitting on a pile of money." - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 1, - "logprob": null, - "text": "" - }, - { - "id": 1247, - "logprob": -5.2421875, - "text": "User" - }, - { - "id": 28747, - "logprob": -6.9570312, - "text": ":" - }, - { - "id": 32000, - "logprob": -16.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -23.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -22.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -23.625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.5, - "text": "" - }, - { - "id": 32001, - "logprob": -19.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -23.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.53125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.328125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -14.8828125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -23.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -23.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -23.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.75, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.34375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.0, - "text": "" - }, - { - "id": 32001, - "logprob": -18.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.25, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.921875, - "text": "" - }, { "id": 32001, "logprob": -16.3125, "text": "" }, - { - "id": 32001, - "logprob": -19.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.953125, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8359375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.03125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.96875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -15.4140625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.7265625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.5625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -14.2421875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.671875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.578125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.25, - "text": "" - }, - { - "id": 32001, - "logprob": -17.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8671875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.640625, - "text": "" - }, - { - "id": 32001, - "logprob": -14.8515625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.25, - "text": "" - }, - { - "id": 32001, - "logprob": -19.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.71875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.5, - "text": "" - }, - { - "id": 32001, - "logprob": -18.296875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -16.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.515625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -20.25, - "text": "" - }, - { - "id": 32001, - "logprob": -20.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.609375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.90625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.5, - "text": "" - }, - { - "id": 32001, - "logprob": -20.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.546875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.484375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.265625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -15.9453125, - "text": "" - }, - { - "id": 32001, - "logprob": -21.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.515625, - "text": "" - }, { "id": 32001, "logprob": -19.375, @@ -6129,152 +3945,27 @@ }, { "id": 32001, - "logprob": -17.796875, + "logprob": -20.046875, "text": "" }, { "id": 32001, - "logprob": -16.03125, + "logprob": -20.828125, "text": "" }, { "id": 32001, - "logprob": -18.671875, + "logprob": -15.8046875, "text": "" }, { "id": 32001, - "logprob": -20.15625, + "logprob": -16.25, "text": "" }, { "id": 32001, - "logprob": -20.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.84375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.78125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.234375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.9375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.703125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.15625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.203125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.6875, - "text": "" - }, - { - "id": 32001, - "logprob": -22.625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -18.46875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.546875, + "logprob": -19.953125, "text": "" }, { @@ -6284,37 +3975,42 @@ }, { "id": 32001, - "logprob": -16.453125, + "logprob": -21.59375, "text": "" }, { "id": 32001, - "logprob": -21.09375, + "logprob": -20.515625, "text": "" }, { "id": 32001, - "logprob": -19.5625, + "logprob": -19.171875, "text": "" }, { "id": 32001, - "logprob": -19.15625, + "logprob": -19.296875, "text": "" }, { "id": 32001, - "logprob": -16.171875, + "logprob": -16.71875, "text": "" }, { "id": 32001, - "logprob": -17.671875, + "logprob": -20.46875, "text": "" }, { "id": 32001, - "logprob": -18.859375, + "logprob": -21.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, "text": "" }, { @@ -6324,37 +4020,77 @@ }, { "id": 32001, - "logprob": -21.8125, + "logprob": -17.765625, "text": "" }, { "id": 32001, - "logprob": -19.96875, + "logprob": -20.328125, "text": "" }, { "id": 32001, - "logprob": -19.046875, + "logprob": -15.2734375, "text": "" }, { "id": 32001, - "logprob": -19.78125, + "logprob": -18.84375, "text": "" }, { "id": 32001, - "logprob": -19.421875, + "logprob": -17.875, "text": "" }, { "id": 32001, - "logprob": -21.21875, + "logprob": -15.578125, "text": "" }, { "id": 32001, - "logprob": -21.515625, + "logprob": -18.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, "text": "" }, { @@ -6364,37 +4100,12 @@ }, { "id": 32001, - "logprob": -20.734375, + "logprob": -19.546875, "text": "" }, { "id": 32001, - "logprob": -19.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -19.828125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.359375, - "text": "" - }, - { - "id": 32001, - "logprob": -17.75, + "logprob": -14.1953125, "text": "" }, { @@ -6404,182 +4115,42 @@ }, { "id": 32001, - "logprob": -18.765625, + "logprob": -18.1875, "text": "" }, { "id": 32001, - "logprob": -20.453125, + "logprob": -17.421875, "text": "" }, { "id": 32001, - "logprob": -19.890625, + "logprob": -20.421875, "text": "" }, { "id": 32001, - "logprob": -16.015625, + "logprob": -20.0, "text": "" }, { "id": 32001, - "logprob": -18.90625, + "logprob": -20.359375, "text": "" }, { "id": 32001, - "logprob": -15.953125, + "logprob": -18.03125, "text": "" }, { "id": 32001, - "logprob": -21.46875, + "logprob": -17.203125, "text": "" }, { "id": 32001, - "logprob": -19.984375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.875, - "text": "" - }, - { - "id": 32001, - "logprob": -18.859375, - "text": "" - }, - { - "id": 32001, - "logprob": -16.046875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.140625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.6875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.453125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.1875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.765625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.65625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -17.109375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.171875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.4375, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0625, - "text": "" - }, - { - "id": 32001, - "logprob": -16.734375, - "text": "" - }, - { - "id": 32001, - "logprob": -19.21875, - "text": "" - }, - { - "id": 32001, - "logprob": -16.421875, - "text": "" - }, - { - "id": 32001, - "logprob": -20.015625, - "text": "" - }, - { - "id": 32001, - "logprob": -17.796875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.3125, - "text": "" - }, - { - "id": 32001, - "logprob": -20.390625, - "text": "" - }, - { - "id": 32001, - "logprob": -19.28125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.59375, - "text": "" - }, - { - "id": 32001, - "logprob": -18.8125, - "text": "" - }, - { - "id": 32001, - "logprob": -19.09375, - "text": "" - }, - { - "id": 32001, - "logprob": -20.890625, + "logprob": -16.84375, "text": "" }, { @@ -6589,72 +4160,32 @@ }, { "id": 32001, - "logprob": -18.75, + "logprob": -15.71875, "text": "" }, { "id": 32001, - "logprob": -18.90625, + "logprob": -18.203125, "text": "" }, { "id": 32001, - "logprob": -21.375, + "logprob": -18.4375, "text": "" }, { "id": 32001, - "logprob": -16.640625, + "logprob": -18.46875, "text": "" }, { "id": 32001, - "logprob": -20.859375, + "logprob": -17.3125, "text": "" }, { "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -15.875, - "text": "" - }, - { - "id": 32001, - "logprob": -15.1171875, - "text": "" - }, - { - "id": 32001, - "logprob": -17.078125, - "text": "" - }, - { - "id": 32001, - "logprob": -18.921875, - "text": "" - }, - { - "id": 32001, - "logprob": -21.40625, - "text": "" - }, - { - "id": 32001, - "logprob": -21.0, - "text": "" - }, - { - "id": 32001, - "logprob": -20.75, - "text": "" - }, - { - "id": 32001, - "logprob": -16.25, + "logprob": -16.265625, "text": "" }, { @@ -6664,37 +4195,47 @@ }, { "id": 32001, - "logprob": -21.59375, + "logprob": -14.734375, "text": "" }, { "id": 32001, - "logprob": -22.421875, + "logprob": -20.6875, "text": "" }, { "id": 32001, - "logprob": -20.9375, + "logprob": -20.21875, "text": "" }, { "id": 32001, - "logprob": -19.671875, + "logprob": -18.359375, "text": "" }, { "id": 32001, - "logprob": -20.890625, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -3.2988281, + "text": "" + }, + { + "id": 32001, + "logprob": -25.75, "text": "" }, { "id": 32001, - "logprob": -16.921875, + "logprob": -18.421875, "text": "" }, { "id": 32001, - "logprob": -17.5, + "logprob": -19.265625, "text": "" }, { @@ -6704,42 +4245,77 @@ }, { "id": 32001, - "logprob": -22.1875, + "logprob": -17.203125, "text": "" }, { "id": 32001, - "logprob": -18.734375, + "logprob": -20.140625, "text": "" }, { "id": 32001, - "logprob": -16.40625, + "logprob": -17.96875, "text": "" }, { "id": 32001, - "logprob": -18.453125, + "logprob": -16.453125, "text": "" }, { "id": 32001, - "logprob": -20.234375, + "logprob": -19.65625, "text": "" }, { "id": 32001, - "logprob": -22.28125, + "logprob": -18.203125, "text": "" }, { "id": 32001, - "logprob": -18.515625, + "logprob": -16.921875, "text": "" }, { "id": 32001, - "logprob": -15.4296875, + "logprob": -15.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, "text": "" }, { @@ -6749,27 +4325,187 @@ }, { "id": 32001, - "logprob": -14.6484375, + "logprob": -19.890625, "text": "" }, { "id": 32001, - "logprob": -21.46875, + "logprob": -20.421875, "text": "" }, { "id": 32001, - "logprob": -18.875, + "logprob": -19.34375, "text": "" }, { "id": 32001, - "logprob": -20.859375, + "logprob": -20.140625, "text": "" }, { "id": 32001, - "logprob": -17.078125, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.703125, "text": "" }, { @@ -6777,11 +4513,581 @@ "logprob": -16.4375, "text": "" }, + { + "id": 32001, + "logprob": -19.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -2.7207031, + "text": "" + }, + { + "id": 32001, + "logprob": -23.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.203125, + "text": "" + }, { "id": 32001, "logprob": -21.015625, "text": "" }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -3.0917969, + "text": "" + }, + { + "id": 32001, + "logprob": -25.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, { "id": 32001, "logprob": -21.234375, @@ -6794,22 +5100,22 @@ }, { "id": 32001, - "logprob": -21.484375, + "logprob": -21.515625, "text": "" }, { "id": 32001, - "logprob": -20.015625, + "logprob": -20.0, "text": "" }, { "id": 32001, - "logprob": -18.84375, + "logprob": -18.78125, "text": "" }, { "id": 32001, - "logprob": -16.421875, + "logprob": -16.375, "text": "" }, { @@ -6819,57 +5125,57 @@ }, { "id": 32001, - "logprob": -16.65625, + "logprob": -16.703125, "text": "" }, { "id": 32001, - "logprob": -13.640625, + "logprob": -13.625, "text": "" }, { "id": 32001, - "logprob": -15.4140625, + "logprob": -15.375, "text": "" }, { "id": 32001, - "logprob": -17.546875, + "logprob": -17.515625, "text": "" }, { "id": 32001, - "logprob": -21.859375, + "logprob": -21.921875, "text": "" }, { "id": 32001, - "logprob": -15.65625, + "logprob": -15.640625, "text": "" }, { "id": 32001, - "logprob": -16.484375, + "logprob": -16.46875, "text": "" }, { "id": 32001, - "logprob": -16.359375, + "logprob": -16.421875, "text": "" }, { "id": 32001, - "logprob": -19.9375, + "logprob": -19.890625, "text": "" }, { "id": 32001, - "logprob": -17.875, + "logprob": -17.890625, "text": "" }, { "id": 32001, - "logprob": -17.453125, + "logprob": -17.40625, "text": "" }, { @@ -6879,72 +5185,72 @@ }, { "id": 32001, - "logprob": -19.171875, + "logprob": -19.1875, "text": "" }, { "id": 32001, - "logprob": -15.9921875, + "logprob": -15.9609375, "text": "" }, { "id": 32000, - "logprob": -2.0429688, + "logprob": -2.0332031, "text": "" }, { "id": 12018, - "logprob": -12.03125, + "logprob": -12.078125, "text": "Write" }, { "id": 528, - "logprob": -10.25, + "logprob": -10.109375, "text": "me" }, { "id": 264, - "logprob": -0.10437012, + "logprob": -0.103515625, "text": "a" }, { "id": 2485, - "logprob": -4.578125, + "logprob": -4.5664062, "text": "short" }, { "id": 2838, - "logprob": -0.22924805, + "logprob": -0.23864746, "text": "story" }, { "id": 32002, - "logprob": -10.84375, + "logprob": -10.9609375, "text": "" }, { "id": 259, - "logprob": -20.171875, + "logprob": -20.34375, "text": " " }, { "id": 13, - "logprob": -8.765625, + "logprob": -8.5546875, "text": "\n" }, { "id": 7226, - "logprob": -10.4140625, + "logprob": -10.484375, "text": "Ass" }, { "id": 11143, - "logprob": -13.640625, + "logprob": -13.6015625, "text": "istant" }, { "id": 28747, - "logprob": -0.005744934, + "logprob": -0.008308411, "text": ":" } ], @@ -6952,61 +5258,1835 @@ "tokens": [ { "id": 330, - "logprob": -0.12976074, + "logprob": -0.09448242, "special": false, "text": " A" }, { "id": 13088, - "logprob": -0.66308594, + "logprob": -0.6743164, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.29541016, + "logprob": -0.31201172, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.05996704, + "logprob": -0.051635742, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.27075195, + "logprob": -0.34033203, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.14160156, + "logprob": -0.1194458, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.040863037, + "logprob": -0.032562256, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.00027036667, + "logprob": -0.00018787384, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.093322754, + "logprob": -0.07122803, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.006931305, + "logprob": -0.0041007996, + "special": false, + "text": "." + } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -5.2382812, + "text": "User" + }, + { + "id": 28747, + "logprob": -6.9492188, + "text": ":" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -23.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.0, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -3.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8046875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.2734375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -3.2988281, + "text": "" + }, + { + "id": 32001, + "logprob": -25.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -2.7207031, + "text": "" + }, + { + "id": 32001, + "logprob": -23.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -3.0917969, + "text": "" + }, + { + "id": 32001, + "logprob": -25.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1328125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -13.625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -2.0332031, + "text": "" + }, + { + "id": 12018, + "logprob": -12.078125, + "text": "Write" + }, + { + "id": 528, + "logprob": -10.109375, + "text": "me" + }, + { + "id": 264, + "logprob": -0.103515625, + "text": "a" + }, + { + "id": 2485, + "logprob": -4.5664062, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.23864746, + "text": "story" + }, + { + "id": 32002, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 259, + "logprob": -20.34375, + "text": " " + }, + { + "id": 13, + "logprob": -8.5546875, + "text": "\n" + }, + { + "id": 7226, + "logprob": -10.484375, + "text": "Ass" + }, + { + "id": 11143, + "logprob": -13.6015625, + "text": "istant" + }, + { + "id": 28747, + "logprob": -0.008308411, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.09448242, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.6743164, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.31201172, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.051635742, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.34033203, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.1194458, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.032562256, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.00018763542, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.07122803, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.0041007996, "special": false, "text": "." } diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json index a3b18d0a..da2ac897 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json @@ -8,61 +8,61 @@ "tokens": [ { "id": 330, - "logprob": -0.13000488, + "logprob": -0.08660889, "special": false, "text": " A" }, { "id": 13088, - "logprob": -0.6713867, + "logprob": -0.7089844, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.2980957, + "logprob": -0.32885742, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.060638428, + "logprob": -0.05126953, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.27319336, + "logprob": -0.35229492, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.140625, + "logprob": -0.12561035, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.040405273, + "logprob": -0.038085938, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.0002708435, + "logprob": -0.00018656254, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.095336914, + "logprob": -0.07293701, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.0068359375, + "logprob": -0.004852295, "special": false, "text": "." } diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json index 86c95b29..bf2dc5a1 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -8,115 +8,115 @@ "tokens": [ { "id": 415, - "logprob": -0.04421997, + "logprob": -0.039886475, "special": false, "text": " The" }, { "id": 12072, - "logprob": -0.13500977, + "logprob": -0.1430664, "special": false, "text": " cow" }, { "id": 349, - "logprob": -0.06750488, + "logprob": -0.056488037, "special": false, "text": " is" }, { "id": 6328, - "logprob": -0.6352539, + "logprob": -0.6855469, "special": false, "text": " standing" }, { "id": 356, - "logprob": -0.16186523, + "logprob": -0.1685791, "special": false, "text": " on" }, { "id": 272, - "logprob": -0.5078125, + "logprob": -0.50097656, "special": false, "text": " the" }, { "id": 10305, - "logprob": -0.017913818, + "logprob": -0.017303467, "special": false, "text": " beach" }, { "id": 304, - "logprob": -1.5205078, + "logprob": -1.3564453, "special": false, "text": " and" }, { "id": 272, - "logprob": -0.029174805, + "logprob": -0.017868042, "special": false, "text": " the" }, { "id": 13088, - "logprob": -0.003479004, + "logprob": -0.0027103424, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.0035095215, + "logprob": -0.003156662, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.3088379, + "logprob": -0.37304688, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.027755737, + "logprob": -0.034576416, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.31884766, + "logprob": -0.29418945, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.047943115, + "logprob": -0.042877197, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.0002925396, + "logprob": -0.00028443336, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.02935791, + "logprob": -0.023223877, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.031219482, + "logprob": -0.018157959, "special": false, "text": "." }, { "id": 32002, - "logprob": -0.00034475327, + "logprob": -0.00018393993, "special": true, "text": "" }, diff --git a/router/Cargo.toml b/router/Cargo.toml index 853f46b1..5855ac86 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,6 +22,7 @@ text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { workspace = true } +itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } metrics = "0.21.1" metrics-exporter-prometheus = { version = "0.15.1", features = [] } diff --git a/router/src/config.rs b/router/src/config.rs index 29fefd5b..ccbdd8b2 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -71,10 +71,12 @@ fn get_unpadded_features( let current_aspect_ratio: f64 = current_width as f64 / current_height as f64; let (current_height, current_width) = if aspect_ratio > current_aspect_ratio { let new_height = (height * current_width) / width; - (new_height, current_width) + let padding = (current_height - new_height) / 2; + (current_height - (2 * padding), current_width) } else { let new_width = (width * current_height) / height; - (current_height, new_width) + let padding = (current_width - new_width) / 2; + (current_height, current_width - (2 * padding)) }; let unpadded_features = current_height * current_width; @@ -88,7 +90,9 @@ impl LlavaNext { let patch_size = self.vision_config.patch_size; assert!(image_size % patch_size == 0); let npatches = image_size / patch_size; - let (num_patch_height, num_patch_width) = + // Dimensions are intentionally swapped to be bug-compatible with + // upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size); let (unpadded_features, newline_features) = @@ -112,7 +116,7 @@ pub struct Idefics2 {} impl Idefics2 { pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { - 320 + 64 } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 126726c6..4ba76f5f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -70,6 +70,25 @@ impl HubTokenizerConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "processor_class")] +pub enum HubPreprocessorConfig { + Idefics2Processor(Idefics2Preprocessor), +} + +impl HubPreprocessorConfig { + pub fn from_file>(filename: P) -> Option { + let content = std::fs::read_to_string(filename).ok()?; + serde_json::from_str(&content).ok() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Idefics2Preprocessor { + #[serde(default)] + do_image_splitting: bool, +} + #[derive(Debug, Clone, Deserialize, Default)] pub struct HubProcessorConfig { pub chat_template: Option, diff --git a/router/src/main.rs b/router/src/main.rs index a7caec2e..68b6b1fc 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -13,7 +13,9 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_router::config::Config; -use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; +use text_generation_router::{ + server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, +}; use thiserror::Error; use tokenizers::Tokenizer; use tower_http::cors::AllowOrigin; @@ -214,6 +216,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_filename, config_filename, tokenizer_config_filename, + preprocessor_config_filename, processor_config_filename, model_info, ) = match api { @@ -221,6 +224,7 @@ async fn main() -> Result<(), RouterError> { Some(local_path.join("tokenizer.json")), Some(local_path.join("config.json")), Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("preprocessor_config.json")), Some(local_path.join("processor_config.json")), None, ), @@ -237,6 +241,7 @@ async fn main() -> Result<(), RouterError> { }; let config_filename = api_repo.get("config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); let processor_config_filename = api_repo.get("processor_config.json").await.ok(); let model_info = if let Some(model_info) = get_model_info(&api_repo).await { @@ -249,6 +254,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_filename, config_filename, tokenizer_config_filename, + preprocessor_config_filename, processor_config_filename, model_info, ) @@ -263,6 +269,7 @@ async fn main() -> Result<(), RouterError> { repo.get("tokenizer.json"), repo.get("config.json"), repo.get("tokenizer_config.json"), + repo.get("preprocessor_config.json"), repo.get("processor_config.json"), None, ) @@ -300,6 +307,8 @@ async fn main() -> Result<(), RouterError> { HubTokenizerConfig::default() }); + let preprocessor_config = + preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); let processor_config = processor_config_filename .and_then(HubProcessorConfig::from_file) .unwrap_or_default(); @@ -361,6 +370,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, tokenizer_config, + preprocessor_config, processor_config, messages_api_enabled, disable_grammar_support, diff --git a/router/src/server.rs b/router/src/server.rs index 7f15bfdd..0cb08d4e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,9 +12,9 @@ use crate::kserve::{ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, - Usage, Validation, + GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, + HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, + Token, TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1423,6 +1423,7 @@ pub async fn run( _ngrok_authtoken: Option, _ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, + preprocessor_config: Option, processor_config: HubProcessorConfig, messages_api_enabled: bool, grammar_support: bool, @@ -1636,6 +1637,7 @@ pub async fn run( validation_workers, tokenizer, config, + preprocessor_config, max_best_of, max_stop_sequences, max_top_n_tokens, diff --git a/router/src/validation.rs b/router/src/validation.rs index e2bf5a5d..12cf2ab3 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,13 +1,16 @@ /// Payload validation logic use crate::config::Config; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; -use crate::{GenerateParameters, GenerateRequest, GrammarType}; +use crate::{ + GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, +}; use base64::{engine::general_purpose::STANDARD, Engine}; use image::{io::Reader as ImageReader, ImageFormat}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; +use std::iter; use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; @@ -36,6 +39,7 @@ impl Validation { workers: usize, tokenizer: Option, config: Option, + preprocessor_config: Option, max_best_of: usize, max_stop_sequences: usize, max_top_n_tokens: u32, @@ -53,12 +57,18 @@ impl Validation { for _ in 0..workers { let tokenizer_clone = tokenizer.clone(); let config_clone = config.clone(); + let preprocessor_config_clone = preprocessor_config.clone(); let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel(); senders.push(tokenizer_sender); // Spawn worker tokio::task::spawn_blocking(move || { - tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver) + tokenizer_worker( + tokenizer_clone, + config_clone, + preprocessor_config_clone, + tokenizer_receiver, + ) }); } @@ -422,13 +432,20 @@ async fn round_robin_task( fn tokenizer_worker( tokenizer: Tokenizer, config: Option, + preprocessor_config: Option, mut receiver: mpsc::UnboundedReceiver, ) { // Loop over requests while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { parent_span.in_scope(|| { response_tx - .send(prepare_input(inputs, truncate, &tokenizer, &config)) + .send(prepare_input( + inputs, + truncate, + &tokenizer, + config.as_ref(), + preprocessor_config.as_ref(), + )) .unwrap_or(()) }) } @@ -508,16 +525,67 @@ fn fetch_image(input: &str) -> Result<(Vec, String, usize, usize), Validatio } } +fn image_tokens( + config: &Config, + preprocessor_config: Option<&HubPreprocessorConfig>, + height: usize, + width: usize, +) -> String { + use Config::*; + use HubPreprocessorConfig::*; + match config { + Idefics => "".to_string(), + Idefics2(config) => { + const FAKE: &str = ""; + const IMAGE: &str = ""; + + let slots = config.get_number_of_features(height, width); + + let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len()); + image_string.push_str(FAKE); + image_string.extend(iter::repeat(IMAGE).take(slots)); + image_string.push_str(FAKE); + + if matches!( + preprocessor_config, + Some(Idefics2Processor(Idefics2Preprocessor { + do_image_splitting: true, + .. + })) + ) { + image_string = image_string.repeat(5); + }; + + image_string + } + Paligemma(config) => "".repeat(config.get_number_of_features(height, width)), + LlavaNext(config) => "".repeat(config.get_number_of_features(height, width)), + _ => unimplemented!("Images tokens are not supported for this model configuration"), + } +} + +fn image_tokens_fixup(config: &Config, text: String) -> String { + match config { + Config::Idefics2(_) => { + const FAKE: &str = ""; + text.replace(&format!("{FAKE}{FAKE}"), FAKE) + } + _ => text, + } +} + /// Get input length and optionally truncate it fn prepare_input( inputs: String, _truncate: Option, tokenizer: &Tokenizer, - config: &Option, + config: Option<&Config>, + preprocessor_config: Option<&HubPreprocessorConfig>, ) -> Result<(tokenizers::Encoding, Vec), ValidationError> { + use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { - Some(Config::LlavaNext(config)) => { + Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => { let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; @@ -529,88 +597,17 @@ fn prepare_input( tokenizer_query.push_str(&inputs[start..chunk_start]); } let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; - let slots = config.get_number_of_features(height, width); input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); - tokenizer_query.push_str(&"".repeat(slots)); + tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width)); start = chunk_end; } if start != inputs.len() { input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - (tokenizer_query, input_chunks) - } - Some(Config::Paligemma(config)) => { - let mut input_chunks = Vec::new(); - let mut tokenizer_query = String::with_capacity(inputs.len()); - let mut start = 0; - for chunk in RE.find_iter(&inputs) { - let chunk_start = chunk.start(); - let chunk_end = chunk.end(); - if chunk_start != start { - input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); - tokenizer_query.push_str(&inputs[start..chunk_start]); - } - let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; - let slots = config.get_number_of_features(height, width); - input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); - tokenizer_query.push_str(&"".repeat(slots)); - start = chunk_end; - } - if start != inputs.len() { - input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); - tokenizer_query.push_str(&inputs[start..]); - } - (tokenizer_query, input_chunks) - } - Some(Config::Idefics2(config)) => { - let mut input_chunks = Vec::new(); - let mut tokenizer_query = String::with_capacity(inputs.len()); - let mut start = 0; - for chunk in RE.find_iter(&inputs) { - let chunk_start = chunk.start(); - let chunk_end = chunk.end(); - if chunk_start != start { - input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); - tokenizer_query.push_str(&inputs[start..chunk_start]); - } - let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; - let slots = config.get_number_of_features(height, width); - tokenizer_query.push_str(""); - tokenizer_query.push_str(&"".repeat(slots)); - tokenizer_query.push_str(""); - input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); - start = chunk_end; - } - if start != inputs.len() { - input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); - tokenizer_query.push_str(&inputs[start..]); - } - (tokenizer_query, input_chunks) - } - Some(Config::Idefics) => { - let mut input_chunks = Vec::new(); - let mut tokenizer_query = String::with_capacity(inputs.len()); - let mut start = 0; - for chunk in RE.find_iter(&inputs) { - let chunk_start = chunk.start(); - let chunk_end = chunk.end(); - if chunk_start != start { - input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); - tokenizer_query.push_str(&inputs[start..chunk_start]); - } - let (data, mimetype, _height, _width) = - fetch_image(&inputs[chunk_start..chunk_end])?; - let slots = 1; - tokenizer_query.push_str(&"".repeat(slots)); - input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); - start = chunk_end; - } - if start != inputs.len() { - input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); - tokenizer_query.push_str(&inputs[start..]); - } + tokenizer_query = image_tokens_fixup(config, tokenizer_query); + (tokenizer_query, input_chunks) } _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), @@ -750,7 +747,7 @@ pub enum ValidationError { #[cfg(test)] mod tests { use super::*; - use crate::config::{PaliTextConfig, Paligemma}; + use crate::config::{Idefics2, PaliTextConfig, Paligemma}; use crate::default_parameters; use crate::tests::get_tokenizer; @@ -769,6 +766,7 @@ mod tests { workers, tokenizer, config, + None, max_best_of, max_stop_sequence, max_top_n_tokens, @@ -803,6 +801,7 @@ mod tests { workers, tokenizer, config, + None, max_best_of, max_stop_sequence, max_top_n_tokens, @@ -836,6 +835,7 @@ mod tests { workers, tokenizer, config, + None, max_best_of, max_stop_sequence, max_top_n_tokens, @@ -874,6 +874,7 @@ mod tests { workers, tokenizer, config, + None, max_best_of, max_stop_sequence, max_top_n_tokens, @@ -941,6 +942,7 @@ mod tests { workers, tokenizer, config, + None, max_best_of, max_stop_sequences, max_top_n_tokens, @@ -1026,6 +1028,7 @@ mod tests { workers, tokenizer, Some(config), + None, max_best_of, max_stop_sequence, max_top_n_tokens, @@ -1058,4 +1061,83 @@ mod tests { "Failed to process images", ); } + + #[tokio::test] + async fn test_idefics2_correct_n_fake_tokens() { + let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap(); + + let tokenizer = Some(get_tokenizer().await); + + let max_best_of = 2; + let max_stop_sequence = 3; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; + let disable_grammar_support = true; + let workers = 1; + let config = Config::Idefics2(Idefics2 {}); + let validation = Validation::new( + workers, + tokenizer, + Some(config), + Some(HubPreprocessorConfig::Idefics2Processor( + Idefics2Preprocessor { + do_image_splitting: true, + }, + )), + max_best_of, + max_stop_sequence, + max_top_n_tokens, + max_input_length, + max_total_tokens, + disable_grammar_support, + ); + + let (encoding, chunks) = match validation + .tokenize( + format!( + "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})", + PIXEL_GIF, PIXEL_GIF + ), + None, + ) + .await + { + Ok(Some((encoding, chunks))) => (encoding, chunks), + _ => panic!("Unexpected tokenization failure"), + }; + + assert!( + chunks + == vec![ + Chunk::Text("test".to_string()).into(), + Chunk::Image(Image { + data: pixel_data.clone(), + mimetype: "image/gif".to_string() + }) + .into(), + Chunk::Image(Image { + data: pixel_data.clone(), + mimetype: "image/gif".to_string() + }) + .into() + ], + "Failed to process images", + ); + + // Verify the number of fake tokens: + // + // - Two images surrounded/separated by a fake token = 3. + // - Both are split in 5 subimages, separated by a fake token: 2 * 4 + // + // Fake tokens get split up by the testing tokenizer, but we don't care. + assert_eq!( + encoding + .get_tokens() + .iter() + .filter(|t| *t == "fake") + .count(), + 11 + ); + } } diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 9a670140..6d38442c 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -39,7 +39,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Args: image_size (`tuple`): - The size of the input image in the format (width, height). + The size of the input image in the format (height, width). grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. @@ -47,7 +47,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): The size of each image patch. Returns: - tuple: The shape of the image patch grid in the format (width, height). + tuple: The shape of the image patch grid in the format (height, width). """ if not isinstance(grid_pinpoints, list): raise ValueError("grid_pinpoints should be a list of tuples or lists") @@ -230,7 +230,10 @@ class LlavaNextForConditionalGeneration(nn.Module): raise ValueError( "The number of patches is not consistent with the image size." ) - num_patch_height, num_patch_width = get_anyres_image_grid_shape( + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_sizes[image_idx], self.config.image_grid_pinpoints, self.config.vision_config.image_size, diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index e883ce02..a167e467 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -39,7 +39,9 @@ class PaliGemmaBatch(VlmCausalLMBatch): # TODO do_convert_RGB should be on by default ? image = image.convert("RGB") image_input = processor.image_processor(image, return_tensors="pt") - full_text += image_text_replacement(image_input, config, image_id) + full_text += image_text_replacement( + processor, image_input, config, image_id + ) image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk_type}") diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 218d1167..1cdf37ea 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,3 +1,4 @@ +from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -15,6 +16,9 @@ from text_generation_server.models.flash_mistral import ( tracer = trace.get_tracer(__name__) +IDEFICS2_FAKE_TOKEN = "" +IDEFICS2_IMAGE_TOKEN = "" + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -22,7 +26,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Args: image_size (`tuple`): - The size of the input image in the format (width, height). + The size of the input image in the format (height, width). grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. @@ -39,15 +43,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def image_text_replacement(image_input, config, image_id) -> str: +def image_text_replacement(processor, image_input, config, image_id: int) -> str: if config.model_type == "idefics2": - # TODO technically depends on image splitting which is not implemented. - num_features = 320 - return ( - "" - + "" * num_features - + "" - ) + image_seq_len = 64 + image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" + if processor.image_processor.do_image_splitting: + image_str *= 5 + return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) @@ -64,20 +66,35 @@ def image_text_replacement(image_input, config, image_id) -> str: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") +def image_text_replacement_fixup(config, text: str) -> str: + if config.model_type == "idefics2": + return text.replace( + f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN + ) + return text + + def get_unpadded_features( - height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, ) -> Tuple[int, int]: current_height = npatches * num_patch_height current_width = npatches * num_patch_width - aspect_ratio: float = width / height + aspect_ratio: float = original_width / original_height current_aspect_ratio: float = current_width / current_height + if aspect_ratio > current_aspect_ratio: - new_height = (height * current_width) // width - current_height = new_height + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) else: - new_width = (width * current_height) // height - current_width = new_width + new_width = (original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) unpadded_features = current_height * current_width newline_features = current_height @@ -96,7 +113,9 @@ def get_number_of_features(height: int, width: int, config) -> int: npatches = image_size // patch_size - num_patch_height, num_patch_width = get_anyres_image_grid_shape( + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( [height, width], image_grid_pinpoints, image_size, @@ -168,9 +187,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch): if chunk_type == "text": full_text += chunk.text elif chunk_type == "image": - full_text += image_text_replacement(image_inputs, config, image_id) + full_text += image_text_replacement( + processor, image_inputs, config, image_id + ) image_id += 1 + full_text = image_text_replacement_fixup(config, full_text) + batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) From befe60b566d47788fd4918847faa6cafe28fbbc2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jun 2024 16:04:03 +0200 Subject: [PATCH 082/340] Fixing malformed rust tokenizers (#2134) * Fixing malformed rust tokenizers * Fix for deepseek too. --- Cargo.lock | 2 -- router/src/config.rs | 1 + router/src/lib.rs | 3 +++ router/src/main.rs | 35 ++++++++++++++++++++++++++++++++--- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 37c2553e..50799b9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3942,8 +3942,6 @@ dependencies = [ "tokio-stream", "tower-http", "tracing", - "tracing-core", - "tracing-log 0.2.0", "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", diff --git a/router/src/config.rs b/router/src/config.rs index ccbdd8b2..7737165e 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -162,6 +162,7 @@ pub enum Config { Baichuan, Paligemma(Paligemma), Gemma, + Gemma2, Cohere, Drbx, Falcon, diff --git a/router/src/lib.rs b/router/src/lib.rs index 4ba76f5f..a5b97af3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -61,6 +61,9 @@ pub struct HubTokenizerConfig { pub bos_token: Option, #[serde(deserialize_with = "token_serde::deserialize")] pub eos_token: Option, + pub tokenizer_class: Option, + pub add_bos_token: Option, + pub add_eos_token: Option, } impl HubTokenizerConfig { diff --git a/router/src/main.rs b/router/src/main.rs index 68b6b1fc..3aa5a6bf 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -17,7 +17,7 @@ use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; use thiserror::Error; -use tokenizers::Tokenizer; +use tokenizers::{processors::template::TemplateProcessing, Tokenizer}; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -275,8 +275,6 @@ async fn main() -> Result<(), RouterError> { ) } }; - let tokenizer: Option = - tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); let config: Option = config_filename.and_then(|filename| { std::fs::read_to_string(filename) .ok() @@ -306,6 +304,37 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); + let tokenizer: Option = + tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer{ + if let Some(class) = &tokenizer_config.tokenizer_class{ + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" { + tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + let mut single = vec![]; + let mut special_tokens = vec![]; + if let Some(true) = &tokenizer_config.add_bos_token{ + if let Some(bos_token) = &tokenizer_config.bos_token{ + let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id"); + special_tokens.push((bos_token.clone(), bos_token_id)); + single.push(bos_token.to_string()); + } + } + single.push("$0".to_string()); + if let Some(true) = &tokenizer_config.add_eos_token{ + if let Some(eos_token) = &tokenizer_config.eos_token{ + let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id"); + special_tokens.push((eos_token.clone(), eos_token_id)); + single.push(eos_token.to_string()); + } + } + let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap(); + tokenizer.with_post_processor(post_processor); + }} + } + tokenizer + + }); let preprocessor_config = preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); From bc15e960eabad6240d4a3fad55b66d600cc2e1c2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jun 2024 16:04:20 +0200 Subject: [PATCH 083/340] Fixing gemma2. (#2135) * Fixing gemma2. * Adding new model. --- docs/source/supported_models.md | 1 + .../text_generation_server/models/__init__.py | 30 ++ .../custom_modeling/flash_gemma2_modeling.py | 500 ++++++++++++++++++ .../custom_modeling/flash_gemma_modeling.py | 2 - .../models/flash_causal_lm.py | 18 +- .../models/flash_gemma2.py | 75 +++ .../text_generation_server/models/globals.py | 5 + 7 files changed, 622 insertions(+), 9 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py create mode 100644 server/text_generation_server/models/flash_gemma2.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 3468e988..1eeed39f 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) +- [Gemma2](https://huggingface.co/google/gemma2-9b) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 648fcee9..f2f0f457 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -68,6 +68,9 @@ try: from text_generation_server.models.flash_gemma import ( FlashGemma, ) + from text_generation_server.models.flash_gemma2 import ( + FlashGemma2, + ) from text_generation_server.models.pali_gemma import ( PaliGemma, ) @@ -102,6 +105,7 @@ if FLASH_ATTENTION: __all__.append(FlashQwen2) __all__.append(FlashStarcoder2) __all__.append(FlashGemma) + __all__.append(FlashGemma2) __all__.append(FlashCohere) MAMBA_AVAILABLE = True @@ -143,6 +147,11 @@ class ModelType(enum.Enum): "name": "Gemma", "url": "https://huggingface.co/google/gemma-7b", } + GEMMA2 = { + "type": "gemma2", + "name": "Gemma2", + "url": "https://huggingface.co/google/gemma2-9b", + } COHERE = { "type": "cohere", "name": "Cohere", @@ -630,6 +639,27 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == GEMMA2: + if FLASH_ATTENTION: + return FlashGemma2( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == COHERE: if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py new file mode 100644 index 00000000..a71de61f --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -0,0 +1,500 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) + + +class Gemma2Config(PretrainedConfig): + def __init__( + self, + vocab_size=256128, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Gemma2FastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix, weights, eps=1e-6): + dtype = weights.dtype + weights.dtype = torch.float32 + weight = weights.get_tensor(f"{prefix}.weight") + 1 + weights.dtype = dtype + new = cls(weight, eps) + new.dtype = dtype + return new + + # perform the multiplication in full precision and downcast after + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * self.weight + return hidden_states.to(self.dtype), residual + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq", "marlin"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class FlashGemma2Attention(torch.nn.Module): + def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + self.causal = causal + if is_sliding: + self.window_size = config.sliding_window + else: + self.window_size = -1 + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + # self.softmax_scale = self.head_size**-0.5 + self.softmax_scale = config.query_pre_attn_scalar**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + causal=self.causal, + window_size_left=self.window_size, + ) + # Decode + else: + paged_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class Gemma2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashGemma2Layer(nn.Module): + def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): + super().__init__() + self.self_attn = FlashGemma2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + causal=causal, + is_sliding=is_sliding, + ) + self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.pre_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.post_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, _ = self.post_attention_layernorm(attn_output) + normed_attn_res_output = normed_attn_res_output + res + res = normed_attn_res_output + + pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output) + mlp_output = self.mlp(pre_normed) + post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output) + + return post_hidden_states, normed_attn_res_output + + +class FlashGemma2Model(torch.nn.Module): + def __init__(self, prefix, config, weights, causal: bool): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + FlashGemma2Layer( + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + causal=causal, + is_sliding=layer_id % 2 == 0, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = inputs_embeds + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemma2ForCausalLM(torch.nn.Module): + def __init__(self, prefix, config, weights, causal: bool): + super().__init__() + + embed_norm = config.hidden_size**0.5 + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.model = FlashGemma2Model( + prefix=prefix, config=config, weights=weights, causal=causal + ) + self.lm_head = SpeculativeHead.load( + prefix=( + f"{prefix}.embed_tokens" + if config.tie_word_embeddings + else f"{prefix}.lm_head" + ), + config=config, + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + input_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index a4fd4740..82891823 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -375,8 +375,6 @@ class FlashGemmaModel(torch.nn.Module): prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) - self.gradient_checkpointing = False - self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f7678762..a0a78b33 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -28,8 +28,12 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS -import text_generation_server.models.globals as tgi_globals +from text_generation_server.models.globals import ( + MEM_POOL, + CUDA_GRAPHS, + get_adapter_to_index, + MODEL_ID, +) from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments @@ -233,7 +237,8 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0) + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) adapter_indices_list.append(torch.full((input_length,), adapter_index)) adapter_set.add(adapter_index) @@ -499,9 +504,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens.append(self.top_n_tokens[idx]) - adapter_index = tgi_globals.ADAPTER_TO_INDEX.get( - self.requests[idx].adapter_id, 0 - ) + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) remaining_tokens = ( @@ -1017,7 +1021,7 @@ class FlashCausalLM(Model): tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, - f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", + f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) logger.info( diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py new file mode 100644 index 00000000..9608113b --- /dev/null +++ b/server/text_generation_server/models/flash_gemma2.py @@ -0,0 +1,75 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from typing import Optional +from transformers import PretrainedConfig, AutoTokenizer + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashGemma2(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError("FlashGemma2 is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = PretrainedConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + # TODO hardcoded + prefix = "" + model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) + + torch.distributed.barrier(group=self.process_group) + super(FlashGemma2, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index cc2f172a..bde86e6e 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -44,3 +44,8 @@ ADAPTER_TO_INDEX: Dict[str, int] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): global ADAPTER_TO_INDEX ADAPTER_TO_INDEX = adapter_to_index + + +def get_adapter_to_index(): + global ADAPTER_TO_INDEX + return ADAPTER_TO_INDEX From 69514868eee7894cd9d3e8655cd5929977a3c0a2 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 27 Jun 2024 17:16:19 -0400 Subject: [PATCH 084/340] fix: refactor post_processor logic and add test (#2137) * fix: refactor post_processor logic and add test * fix: remove dev comment * fix: adjust when post_processor is overridden and improve create_post_processor --- router/src/main.rs | 148 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 118 insertions(+), 30 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 3aa5a6bf..1e8093d8 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -304,37 +304,21 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); - let tokenizer: Option = - tokenizer_filename.and_then(|filename| { - let mut tokenizer = Tokenizer::from_file(filename).ok(); - if let Some(tokenizer) = &mut tokenizer{ - if let Some(class) = &tokenizer_config.tokenizer_class{ - if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" { - tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); - let mut single = vec![]; - let mut special_tokens = vec![]; - if let Some(true) = &tokenizer_config.add_bos_token{ - if let Some(bos_token) = &tokenizer_config.bos_token{ - let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id"); - special_tokens.push((bos_token.clone(), bos_token_id)); - single.push(bos_token.to_string()); - } - } - single.push("$0".to_string()); - if let Some(true) = &tokenizer_config.add_eos_token{ - if let Some(eos_token) = &tokenizer_config.eos_token{ - let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id"); - special_tokens.push((eos_token.clone(), eos_token_id)); - single.push(eos_token.to_string()); - } - } - let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap(); - tokenizer.with_post_processor(post_processor); - }} - } - tokenizer - }); + let tokenizer: Option = tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer { + if let Some(class) = &tokenizer_config.tokenizer_class { + if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() { + if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { + tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + tokenizer.with_post_processor(post_processor); + } + } + } + } + tokenizer + }); let preprocessor_config = preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); @@ -543,6 +527,77 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option Result { + let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); + let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); + + let bos_token = tokenizer_config.bos_token.as_ref(); + let eos_token = tokenizer_config.eos_token.as_ref(); + + if add_bos_token && bos_token.is_none() { + panic!("add_bos_token = true but bos_token is None"); + } + + if add_eos_token && eos_token.is_none() { + panic!("add_eos_token = true but eos_token is None"); + } + + let mut single = Vec::new(); + let mut pair = Vec::new(); + let mut special_tokens = Vec::new(); + + if add_bos_token { + if let Some(bos) = bos_token { + let bos_token_id = tokenizer + .token_to_id(bos) + .expect("Should have found the bos token id"); + special_tokens.push((bos.clone(), bos_token_id)); + single.push(format!("{}:0", bos)); + pair.push(format!("{}:0", bos)); + } + } + + single.push("$A:0".to_string()); + pair.push("$A:0".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + let eos_token_id = tokenizer + .token_to_id(eos) + .expect("Should have found the eos token id"); + special_tokens.push((eos.clone(), eos_token_id)); + single.push(format!("{}:0", eos)); + pair.push(format!("{}:0", eos)); + } + } + + if add_bos_token { + if let Some(bos) = bos_token { + single.push(format!("{}:1", bos)); + } + } + + pair.push("$B:1".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + pair.push(format!("{}:1", eos)); + } + } + + let post_processor = TemplateProcessing::builder() + .try_single(single)? + .try_pair(pair)? + .special_tokens(special_tokens) + .build()?; + + Ok(post_processor) +} + #[derive(Debug, Error)] enum RouterError { #[error("Argument validation error: {0}")] @@ -552,3 +607,36 @@ enum RouterError { #[error("Tokio runtime failed to start: {0}")] Tokio(#[from] std::io::Error), } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_post_processor() { + let tokenizer_config = HubTokenizerConfig { + add_bos_token: None, + add_eos_token: None, + bos_token: Some("".to_string()), + eos_token: Some("".to_string()), + chat_template: None, + tokenizer_class: None, + completion_template: None, + }; + + let tokenizer = + Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap(); + let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); + + let expected = TemplateProcessing::builder() + .try_single(":0 $A:0 :1") + .unwrap() + .try_pair(":0 $A:0 $B:1") + .unwrap() + .special_tokens(vec![("".to_string(), 1)]) + .build() + .unwrap(); + + assert_eq!(post_processor, expected); + } +} From 8721b601e32113d3f0591461f37eb2c71dc5ceac Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 1 Jul 2024 17:27:53 +0800 Subject: [PATCH 085/340] =?UTF-8?q?fix=20microsoft/Phi-3-mini-4k-instruct?= =?UTF-8?q?=20crash=20in=20batch.slots[batch.slot=5F=E2=80=A6=20(#2148)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix microsoft/Phi-3-mini-4k-instruct crash in batch.slots[batch.slot_indices] Signed-off-by: Wang, Yi A * Apply suggestions from code review --------- Signed-off-by: Wang, Yi A Co-authored-by: Nicolas Patry --- router/src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 1e8093d8..a8651b67 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -17,7 +17,7 @@ use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; use thiserror::Error; -use tokenizers::{processors::template::TemplateProcessing, Tokenizer}; +use tokenizers::{processors::template::TemplateProcessing, Tokenizer, PostProcessor}; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -309,7 +309,7 @@ async fn main() -> Result<(), RouterError> { let mut tokenizer = Tokenizer::from_file(filename).ok(); if let Some(tokenizer) = &mut tokenizer { if let Some(class) = &tokenizer_config.tokenizer_class { - if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() { + if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast"){ if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); tokenizer.with_post_processor(post_processor); @@ -577,7 +577,7 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { - single.push(format!("{}:1", bos)); + pair.push(format!("{}:1", bos)); } } From 03691f6d345aa49a2f27bbadad03137647f5d9eb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 12:02:19 +0200 Subject: [PATCH 086/340] Fixing clippy. (#2149) --- router/src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index a8651b67..8a5cf459 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -17,7 +17,7 @@ use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; use thiserror::Error; -use tokenizers::{processors::template::TemplateProcessing, Tokenizer, PostProcessor}; +use tokenizers::{processors::template::TemplateProcessing, Tokenizer}; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -309,7 +309,7 @@ async fn main() -> Result<(), RouterError> { let mut tokenizer = Tokenizer::from_file(filename).ok(); if let Some(tokenizer) = &mut tokenizer { if let Some(class) = &tokenizer_config.tokenizer_class { - if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast"){ + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); tokenizer.with_post_processor(post_processor); From 3e02d4fdbf3143beb28bcf3f2d5adb05a6494469 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 1 Jul 2024 06:58:40 -0400 Subject: [PATCH 087/340] fix: use weights from base_layer (#2141) --- .../models/custom_modeling/flash_llama_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c48ed268..6b82aeca 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -309,7 +309,9 @@ class LlamaMLP(nn.Module): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + _custom_C.LLMM_Silu( + self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 + ) return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) From de96056c2662ece06a2f4c56edf9df6ab30664fb Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 1 Jul 2024 06:58:49 -0400 Subject: [PATCH 088/340] feat: download lora adapter weights from launcher (#2140) --- launcher/src/main.rs | 46 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e6928958..04358917 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -902,13 +902,20 @@ enum LauncherError { WebserverCannotStart, } -fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { +fn download_convert_model( + model_id: &str, + revision: Option<&str>, + trust_remote_code: bool, + huggingface_hub_cache: Option<&str>, + weights_cache_override: Option<&str>, + running: Arc, +) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let mut download_args = vec![ "download-weights".to_string(), - args.model_id.to_string(), + model_id.to_string(), "--extension".to_string(), ".safetensors".to_string(), "--logger-level".to_string(), @@ -917,13 +924,13 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L ]; // Model optional revision - if let Some(revision) = &args.revision { + if let Some(revision) = &revision { download_args.push("--revision".to_string()); download_args.push(revision.to_string()) } // Trust remote code for automatic peft fusion - if args.trust_remote_code { + if trust_remote_code { download_args.push("--trust-remote-code".to_string()); } @@ -938,7 +945,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // If huggingface_hub_cache is set, pass it to the download process // Useful when running inside a docker container - if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { + if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; @@ -956,7 +963,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // If args.weights_cache_override is some, pass it to the download process // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = &args.weights_cache_override { + if let Some(weights_cache_override) = &weights_cache_override { envs.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), @@ -964,7 +971,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L }; // Start process - tracing::info!("Starting download process."); + tracing::info!("Starting check and download process for {model_id}"); let mut download_process = match Command::new("text-generation-server") .args(download_args) .env_clear() @@ -1006,7 +1013,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L loop { if let Some(status) = download_process.try_wait().unwrap() { if status.success() { - tracing::info!("Successfully downloaded weights."); + tracing::info!("Successfully downloaded weights for {model_id}"); break; } @@ -1571,7 +1578,28 @@ fn main() -> Result<(), LauncherError> { .expect("Error setting Ctrl-C handler"); // Download and convert model weights - download_convert_model(&args, running.clone())?; + download_convert_model( + &args.model_id, + args.revision.as_deref(), + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + + // Download and convert lora adapters if any + if let Some(lora_adapters) = &args.lora_adapters { + for adapter in lora_adapters.split(',') { + download_convert_model( + adapter, + None, + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + } + } if !running.load(Ordering::SeqCst) { // Launcher was asked to stop From e0d168ba20eb7ad27bdbadfb891a7c31629e34da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 1 Jul 2024 12:59:12 +0200 Subject: [PATCH 089/340] Use GPTQ-Marlin for supported GPTQ configurations (#2111) GPTQ-Marlin is currently the best-performing kernel for GPTQ models. So let's use it by default if the kernels are installed, the GPU supports it, and the kernels support the configuration. For models generated by `text-generation-server quantize`, use `sym=False`. This subcommand symmetric quantization since the beginning and incorrectly reporting the model to be symmetric will use GPTQ-Marlin (which does not support asymmetric quantization). --- .../test_flash_llama_gptq_marlin.json | 84 ----- ...st_flash_llama_gptq_marlin_all_params.json | 84 ----- .../test_flash_llama_gptq_marlin_load.json | 338 ------------------ .../models/test_flash_llama_gptq_marlin.py | 68 ---- .../layers/gptq/__init__.py | 10 + .../text_generation_server/layers/linear.py | 67 ++-- .../text_generation_server/layers/marlin.py | 15 + .../text_generation_server/utils/weights.py | 200 +++++------ 8 files changed, 144 insertions(+), 722 deletions(-) delete mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json delete mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json delete mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json delete mode 100644 integration-tests/models/test_flash_llama_gptq_marlin.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json deleted file mode 100644 index 0f99d259..00000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6230469, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.046875, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1425781, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.9238281, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.076660156, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10821533, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" -} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json deleted file mode 100644 index 4152b5b3..00000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": 0, - "tokens": [ - { - "id": 13, - "logprob": -2.2539062, - "special": false, - "text": "." - }, - { - "id": 578, - "logprob": -0.15563965, - "special": false, - "text": " The" - }, - { - "id": 3622, - "logprob": -0.8203125, - "special": false, - "text": " server" - }, - { - "id": 706, - "logprob": 0.0, - "special": false, - "text": " has" - }, - { - "id": 539, - "logprob": 0.0, - "special": false, - "text": " not" - }, - { - "id": 3686, - "logprob": 0.0, - "special": false, - "text": " yet" - }, - { - "id": 3288, - "logprob": 0.0, - "special": false, - "text": " sent" - }, - { - "id": 904, - "logprob": 0.0, - "special": false, - "text": " any" - }, - { - "id": 828, - "logprob": 0.0, - "special": false, - "text": " data" - }, - { - "id": 382, - "logprob": -1.5517578, - "special": false, - "text": ".\n\n" - } - ], - "top_tokens": null - }, - "generated_text": "Test request. The server has not yet sent any data.\n\n" -} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json deleted file mode 100644 index 75e90303..00000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json +++ /dev/null @@ -1,338 +0,0 @@ -[ - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - } -] diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py deleted file mode 100644 index 2274abce..00000000 --- a/integration-tests/models/test_flash_llama_gptq_marlin.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest - - -@pytest.fixture(scope="module") -def flash_llama_gptq_marlin_handle(launcher): - with launcher( - "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin" - ) as handle: - yield handle - - -@pytest.fixture(scope="module") -async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): - await flash_llama_gptq_marlin_handle.health(300) - return flash_llama_gptq_marlin_handle.client - - -@pytest.mark.release -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): - response = await flash_llama_gptq_marlin.generate( - "Test request", max_new_tokens=10, decoder_input_details=True - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.release -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin_all_params( - flash_llama_gptq_marlin, response_snapshot -): - response = await flash_llama_gptq_marlin.generate( - "Test request", - max_new_tokens=10, - repetition_penalty=1.2, - return_full_text=True, - temperature=0.5, - top_p=0.9, - top_k=10, - truncate=5, - typical_p=0.9, - watermark=True, - decoder_input_details=True, - seed=0, - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.release -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin_load( - flash_llama_gptq_marlin, generate_load, response_snapshot -): - responses = await generate_load( - flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4 - ) - - assert len(responses) == 4 - assert all([r.generated_text == responses[0].generated_text for r in responses]) - - assert responses == response_snapshot diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 1172775f..56080145 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import ( ) +@dataclass +class GPTQParams: + bits: int + checkpoint_format: Optional[str] + groupsize: int + desc_act: bool + quant_method: str + sym: bool + + @dataclass class GPTQWeight: qweight: torch.Tensor diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index dd48465f..e94e5465 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -166,35 +166,45 @@ def get_linear(weight, bias, quantize): elif quantize == "gptq": from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + GPTQMarlinLinear, + GPTQMarlinWeight, + ) - if not isinstance(weight, GPTQWeight): + if isinstance(weight, GPTQMarlinWeight): + linear = GPTQMarlinLinear( + weight=weight, + bias=bias, + ) + elif isinstance(weight, GPTQWeight): + if weight.use_exllama: + try: + from text_generation_server.layers.gptq import ( + ExllamaQuantLinear, + ) + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + linear = ExllamaQuantLinear(weight, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + linear = QuantLinear( + weight.qweight, + weight.qzeros, + weight.scales, + weight.g_idx, + bias, + weight.bits, + weight.groupsize, + ) + else: raise NotImplementedError( f"The passed weight is not `gptq` compatible, loader needs to be updated." ) - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - linear = ExllamaQuantLinear(weight, bias) - else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, - ) elif quantize == "awq": from text_generation_server.layers.gptq import GPTQWeight @@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize): from text_generation_server.layers.marlin import ( GPTQMarlin24Linear, GPTQMarlin24Weight, - GPTQMarlinLinear, - GPTQMarlinWeight, MarlinLinear, MarlinWeight, ) - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, GPTQMarlin24Weight): + if isinstance(weight, GPTQMarlin24Weight): linear = GPTQMarlin24Linear( weight=weight, bias=bias, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 2207b2e4..a1af67a3 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -3,6 +3,8 @@ from typing import List, Optional, Tuple import torch import torch.nn as nn + +from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.import_utils import SYSTEM try: @@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_TILE_SIZE = 16 +def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: + return ( + SYSTEM == "cuda" + and marlin_kernels is not None + and has_sm_8_0 + and quantize == "gptq" + and gptq_params.quant_method == "gptq" + and gptq_params.bits in GPTQ_MARLIN_BITS + and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES + and gptq_params.sym + ) + + def _check_marlin_kernels(): if not (SYSTEM == "cuda" and has_sm_8_0): raise NotImplementedError( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 348d215c..3731fd24 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,25 +1,15 @@ import os -from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.log import log_once -@dataclass -class _GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool - - class Weights: def __init__( self, @@ -212,6 +202,10 @@ class Weights: """ if quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) try: qweight = self.get_packed_sharded( @@ -221,17 +215,28 @@ class Weights: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." ) - - gptq_params = self._get_gptq_params() - - qzeros = self.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) scales = self.get_packed_sharded( f"{prefix}.scales", dim=1, block_sizes=block_sizes ) scales = scales.to(dtype=self.dtype) + gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + g_idx = self.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) + + qzeros = self.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) if quantize == "gptq" and gptq_params.quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") elif quantize == "gptq" and gptq_params.quant_method == "awq": @@ -269,7 +274,6 @@ class Weights: repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if is_marlin_24: B = self.get_packed_sharded( @@ -286,31 +290,6 @@ class Weights: weight = GPTQMarlin24Weight( B=B, B_meta=B_meta, s=s, bits=gptq_params.bits ) - elif quant_method == "gptq": - gptq_params = self._get_gptq_params() - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - g_idx = self.get_tensor(f"{prefix}.g_idx") - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) else: B = self.get_packed_sharded( f"{prefix}.B", dim=1, block_sizes=block_sizes @@ -356,6 +335,10 @@ class Weights: raise ValueError("get_multi_weights_col is not supported for exl2") elif quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) try: qweight = torch.cat( @@ -366,14 +349,31 @@ class Weights: f"Cannot load `{quantize}` weight, make sure the model is already quantized" ) - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) scales = torch.cat( [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) from text_generation_server.layers.gptq import HAS_EXLLAMA @@ -425,10 +425,8 @@ class Weights: from text_generation_server.layers.marlin import ( GPTQMarlin24Weight, MarlinWeight, - repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if is_marlin_24: try: @@ -452,36 +450,6 @@ class Weights: weight = GPTQMarlin24Weight( B=B, B_meta=B_meta, s=s, bits=gptq_params.bits ) - elif quant_method == "gptq": - gptq_params = self._get_gptq_params() - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], - dim=1, - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) else: try: B = torch.cat( @@ -544,9 +512,41 @@ class Weights: ) elif quantize == "gptq": - use_exllama = True - gptq_params = self._get_gptq_params() + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + if gptq_params.desc_act or gptq_params.groupsize == -1: + scales = self.get_tensor(f"{prefix}.scales") + else: + scales = self.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = self.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True if gptq_params.bits != 4: use_exllama = False @@ -672,10 +672,8 @@ class Weights: from text_generation_server.layers.marlin import ( GPTQMarlin24Weight, MarlinWeight, - repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if is_marlin_24: try: @@ -698,35 +696,6 @@ class Weights: weight = GPTQMarlin24Weight( B=B, B_meta=B_meta, s=s, bits=gptq_params.bits ) - elif quant_method == "gptq": - log_once(logger.info, "Converting GPTQ model to Marlin packing format.") - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) else: try: B = self.get_sharded(f"{prefix}.B", dim=0) @@ -743,18 +712,17 @@ class Weights: else: s = self.get_sharded(f"{prefix}.s", dim=0) weight = MarlinWeight(B=B, s=s) - else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> _GPTQParams: + def _get_gptq_params(self) -> GPTQParams: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() checkpoint_format = getattr(self, "gptq_checkpoint_format", None) desc_act = False - sym = True + sym = False quant_method = "gptq" except (SafetensorError, RuntimeError) as e: try: @@ -767,7 +735,7 @@ class Weights: except Exception: raise e - return _GPTQParams( + return GPTQParams( bits=bits, checkpoint_format=checkpoint_format, desc_act=desc_act, From 5b977c31419878c3d3f57b4db281833b877d6c09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?icyboy=E2=84=A2?= Date: Mon, 1 Jul 2024 20:17:22 +0800 Subject: [PATCH 090/340] fix AttributeError: 'MixtralLayer' object has no attribute 'mlp' (#2123) https://github.com/huggingface/text-generation-inference/issues/2122 --- server/text_generation_server/models/flash_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 209eca83..0f5746de 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -153,7 +153,7 @@ class BaseFlashMistral(FlashCausalLM): # TODO: this is a hack to avoid the gate_proj for # FlashStarcoder2 that doesnt have these layers - if hasattr(layer.mlp, "gate_up_proj"): + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): layer_weights[(i, "gate_proj")] = ( f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj, From 6265956bc4020365fbaecc61e5923db364a5c5f4 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 1 Jul 2024 20:32:54 +0800 Subject: [PATCH 091/340] refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel platform (#2132) * refine get xpu free memory Signed-off-by: Wang, Yi A * enable qwen2 in xpu Signed-off-by: Wang, Yi A * enable gemma/gemma2/phi in intel platform Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A --- .../text_generation_server/layers/attention/ipex.py | 3 ++- server/text_generation_server/models/flash_gemma.py | 8 ++++++++ server/text_generation_server/models/flash_gemma2.py | 8 ++++++++ server/text_generation_server/models/flash_phi.py | 8 ++++++++ server/text_generation_server/models/flash_qwen2.py | 8 ++++++++ server/text_generation_server/utils/import_utils.py | 12 ++++++++---- 6 files changed, 42 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index bfab0119..7f086b68 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -14,6 +14,7 @@ def attention( max_s, softmax_scale, window_size_left=-1, + causal=True, ): # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( @@ -28,7 +29,7 @@ def attention( 0.0, softmax_scale, False, - True, + causal, False, None, ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index aa1ae9ac..7e2b8780 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.bfloat16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGemma is only available on GPU") diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py index 9608113b..86cfc7e2 100644 --- a/server/text_generation_server/models/flash_gemma2.py +++ b/server/text_generation_server/models/flash_gemma2.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.bfloat16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGemma2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 7e108d05..a530d1c3 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashPhi is only available on GPU") diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 23528f0b..cd6078f1 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -19,6 +19,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashQwen2 is only available on GPU") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 6d921721..011e0f63 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,6 +1,7 @@ import torch from loguru import logger import subprocess +import os def is_ipex_available(): @@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction): total_memory = torch.xpu.get_device_properties(device).total_memory device_id = device.index - query = f"xpu-smi dump -d {device_id} -m 18 -n 1" - output = subprocess.check_output(query.split()).decode("utf-8").split("\n") - used_memory = float(output[1].split(",")[-1]) * 1024 * 1024 - free_memory = int(total_memory * 0.95 - used_memory) + memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + free_memory = max( + 0, + int( + total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) + ), + ) return free_memory From 381c5c02a64d84094ab94ec66e1b0b8d0cb29bb9 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 1 Jul 2024 09:08:05 -0400 Subject: [PATCH 092/340] fix: prefer serde structs over custom functions (#2127) * fix: prefer enum for chat object * fix: adjust typo * fix: enum CompletionType not ObjectType * fix: adjust typo * feat: leverage serde for conditional deser * fix: adjust HubTokenizerConfig after rebase * fix: update create_post_processor logic for token type * fix: adjust unwrap syntax in template * Fixing the post processor. --------- Co-authored-by: Nicolas Patry --- router/src/infer/mod.rs | 28 +++- router/src/lib.rs | 315 +++++++++++++++++++--------------------- router/src/main.rs | 29 ++-- router/src/server.rs | 38 ++--- 4 files changed, 207 insertions(+), 203 deletions(-) diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3..49282eb9 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, +}; +use crate::{ + FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, }; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -270,7 +272,11 @@ struct ChatTemplate { } impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + fn new( + template: String, + bos_token: Option, + eos_token: Option, + ) -> Self { let mut env = Box::new(Environment::new()); // enable things like .strip() or .capitalize() env.set_unknown_method_callback(pycompat::unknown_method_callback); @@ -287,8 +293,8 @@ impl ChatTemplate { Self { template, - bos_token, - eos_token, + bos_token: bos_token.map(|token| token.as_str().to_string()), + eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, } } @@ -301,9 +307,9 @@ impl ChatTemplate { if self.use_default_tool_template { if let Some(last_message) = messages.last_mut() { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { + last_message.content.push(MessageChunk::Text { text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); + }); } } } @@ -340,6 +346,14 @@ impl ToolGrammar { .unwrap_or_else(|| panic!("Tool with name {} not found", name)) .clone()] } + ToolType::Function { function } => { + let tool = req_tools + .iter() + .find(|tool| tool.function.name == function.name) + .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) + .clone(); + vec![tool] + } ToolType::OneOf => req_tools.to_owned(), }; diff --git a/router/src/lib.rs b/router/src/lib.rs index a5b97af3..9ecfa051 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,23 +53,40 @@ pub enum ChatTemplateVersions { Multiple(Vec), } +use std::path::Path; + #[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub bos_token: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub eos_token: Option, + pub bos_token: Option, + pub eos_token: Option, pub tokenizer_class: Option, pub add_bos_token: Option, pub add_eos_token: Option, } impl HubTokenizerConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(untagged)] +pub enum TokenizerConfigToken { + String(String), + Object { content: String }, +} + +impl TokenizerConfigToken { + pub fn as_str(&self) -> &str { + match self { + TokenizerConfigToken::String(s) => s, + TokenizerConfigToken::Object { content } => content, + } } } @@ -100,9 +117,10 @@ pub struct HubProcessorConfig { } impl HubProcessorConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) } } @@ -121,35 +139,6 @@ pub(crate) enum GrammarType { Regex(String), } -mod token_serde { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; - - match value { - Value::String(s) => Ok(Some(s)), - Value::Object(map) => { - if let Some(content) = map.get("content").and_then(|v| v.as_str()) { - Ok(Some(content.to_string())) - } else { - Err(de::Error::custom( - "content key not found in structured token", - )) - } - } - Value::Null => Ok(None), - _ => Err(de::Error::custom("invalid token format")), - } - } -} - #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { /// Model info @@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters { } } -mod prompt_serde { - use serde::{self, Deserialize, Deserializer}; - use serde_json::Value; +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[serde(try_from = "PromptDeserializer")] +pub struct Prompt(pub Vec); - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Deserialize)] +#[serde(untagged)] +enum PromptDeserializer { + Single(String), + Multiple(Vec), +} + +impl TryFrom for Prompt { + type Error = String; + + fn try_from(value: PromptDeserializer) -> Result { match value { - Value::String(s) => Ok(vec![s]), - Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom( - "Empty array detected. Do not use an empty array for the prompt.", - )), - Value::Array(arr) => arr - .iter() - .map(|v| match v { - Value::String(s) => Ok(s.to_owned()), - _ => Err(serde::de::Error::custom("Expected a string")), - }) - .collect(), - _ => Err(serde::de::Error::custom( - "Expected a string or an array of strings", - )), + PromptDeserializer::Single(s) => Ok(Prompt(vec![s])), + PromptDeserializer::Multiple(v) => { + if v.is_empty() { + Err( + "Empty array detected. Do not use an empty array for the prompt." + .to_string(), + ) + } else { + Ok(Prompt(v)) + } + } } } } @@ -396,8 +388,7 @@ pub struct CompletionRequest { /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] - #[serde(deserialize_with = "prompt_serde::deserialize")] - pub prompt: Vec, + pub prompt: Prompt, /// The maximum number of tokens that can be generated in the chat completion. #[serde(default)] @@ -445,7 +436,6 @@ pub struct CompletionRequest { #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] pub(crate) struct Completion { pub id: String, - pub object: String, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, - pub object: String, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -562,6 +551,15 @@ pub(crate) struct Usage { pub total_tokens: u32, } +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum CompletionType { + #[serde(rename = "chat.completion.chunk")] + ChatCompletionChunk(ChatCompletionChunk), + #[serde(rename = "chat.completion")] + ChatCompletion(ChatCompletion), +} + impl ChatCompletion { pub(crate) fn new( model: String, @@ -598,7 +596,6 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "chat.completion".into(), created, model, system_fingerprint, @@ -620,7 +617,6 @@ impl ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct CompletionCompleteChunk { pub id: String, - pub object: String, pub created: u64, pub choices: Vec, pub model: String, @@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk { #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, - pub object: String, #[schema(example = "1706270978")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -710,7 +705,6 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "chat.completion.chunk".to_string(), created, model, system_fingerprint, @@ -821,7 +815,6 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] pub tool_choice: Option, /// Response format constraints for the generation. @@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option { "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), ) } -#[derive(Clone, Deserialize, ToSchema, Serialize)] -enum ToolType { - FunctionName(String), + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] +#[serde(untagged)] +pub enum ToolType { OneOf, + FunctionName(String), + Function { function: FunctionName }, } -/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) -mod deserialize_tool_choice { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FunctionName { + pub name: String, +} - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(from = "ToolTypeDeserializer")] +pub struct ToolChoice(pub Option); +#[derive(Deserialize)] +#[serde(untagged)] +enum ToolTypeDeserializer { + None(Option), + Some(ToolType), +} + +impl From for ToolChoice { + fn from(value: ToolTypeDeserializer) -> Self { match value { - Value::String(s) => match s.as_str() { - "none" => Ok(None), - "auto" => Ok(Some(ToolType::OneOf)), - _ => Ok(Some(ToolType::FunctionName(s))), + ToolTypeDeserializer::None(opt) => match opt.as_deref() { + Some("none") => ToolChoice(None), + Some("auto") => ToolChoice(Some(ToolType::OneOf)), + Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), + None => ToolChoice(Some(ToolType::OneOf)), }, - Value::Object(map) => { - if let Some(content) = map - .get("function") - .and_then(|v| v.get("name")) - .and_then(|v| v.as_str()) - { - Ok(Some(ToolType::FunctionName(content.to_string()))) - } else { - Err(de::Error::custom("function key not found in tool choice")) - } - } - Value::Null => Ok(Some(ToolType::OneOf)), - _ => Err(de::Error::custom("invalid token format")), + ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), } } } @@ -950,26 +940,16 @@ pub(crate) struct ToolCall { } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Url { +pub struct Url { url: String, } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct ImageUrl { - image_url: Url, -} - -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Text { - text: String, -} - #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] -enum MessageChunk { - Text(Text), - ImageUrl(ImageUrl), +pub enum MessageChunk { + Text { text: String }, + ImageUrl { image_url: Url }, } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] @@ -977,35 +957,31 @@ pub struct Message { #[schema(example = "user")] role: String, #[schema(example = "My name is David and I")] - #[serde(deserialize_with = "message_content_serde::deserialize")] - content: Vec, + pub content: MessageContent, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] name: Option, } -mod message_content_serde { - use super::*; - use serde::{Deserialize, Deserializer}; +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + SingleText(String), + MultipleChunks(Vec), +} - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum Message { - Text(String), - Chunks(Vec), - } - let message: Message = Deserialize::deserialize(deserializer)?; - let chunks = match message { - Message::Text(text) => { - vec![MessageChunk::Text(Text { text })] +// Pushing a chunk to a single text message will convert it to a multiple chunks message +impl MessageContent { + pub fn push(&mut self, chunk: MessageChunk) { + match self { + MessageContent::SingleText(text) => { + *self = + MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); } - Message::Chunks(s) => s, - }; - Ok(chunks) + MessageContent::MultipleChunks(chunks) => { + chunks.push(chunk); + } + } } } @@ -1021,18 +997,17 @@ impl From for TextMessage { fn from(value: Message) -> Self { TextMessage { role: value.role, - content: value - .content - .into_iter() - .map(|c| match c { - MessageChunk::Text(Text { text }) => text, - MessageChunk::ImageUrl(image) => { - let url = image.image_url.url; - format!("![]({url})") - } - }) - .collect::>() - .join(""), + content: match value.content { + MessageContent::SingleText(text) => text, + MessageContent::MultipleChunks(chunks) => chunks + .into_iter() + .map(|chunk| match chunk { + MessageChunk::Text { text } => text, + MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), + }) + .collect::>() + .join(""), + }, } } } @@ -1240,9 +1215,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::String( + "<|begin▁of▁sentence|>".to_string() + )) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::String( + "<|end▁of▁sentence|>".to_string() + )) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); // in this case we expect the tokens to be encoded as structured tokens // we want the content of the structured token @@ -1275,9 +1257,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::Object { + content: "<|begin▁of▁sentence|>".to_string() + }) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::Object { + content: "<|end▁of▁sentence|>".to_string() + }) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); } #[test] @@ -1295,9 +1284,7 @@ mod tests { request.messages[0], Message { role: "user".to_string(), - content: vec![MessageChunk::Text(Text { - text: "What is Deep Learning?".to_string() - }),], + content: MessageContent::SingleText("What is Deep Learning?".to_string()), name: None } ); @@ -1321,10 +1308,10 @@ mod tests { request.messages[0], Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, + ]), name: None } ); @@ -1334,10 +1321,10 @@ mod tests { fn text_message_convert() { let message = Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), name: None }; let textmsg: TextMessage = message.into(); diff --git a/router/src/main.rs b/router/src/main.rs index 8a5cf459..8618f57e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -553,11 +553,11 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { let bos_token_id = tokenizer - .token_to_id(bos) + .token_to_id(bos.as_str()) .expect("Should have found the bos token id"); - special_tokens.push((bos.clone(), bos_token_id)); - single.push(format!("{}:0", bos)); - pair.push(format!("{}:0", bos)); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); } } @@ -567,17 +567,17 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { let eos_token_id = tokenizer - .token_to_id(eos) + .token_to_id(eos.as_str()) .expect("Should have found the eos token id"); - special_tokens.push((eos.clone(), eos_token_id)); - single.push(format!("{}:0", eos)); - pair.push(format!("{}:0", eos)); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); } } if add_bos_token { if let Some(bos) = bos_token { - pair.push(format!("{}:1", bos)); + pair.push(format!("{}:1", bos.as_str())); } } @@ -585,7 +585,7 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos)); + pair.push(format!("{}:1", eos.as_str())); } } @@ -611,14 +611,15 @@ enum RouterError { #[cfg(test)] mod tests { use super::*; + use text_generation_router::TokenizerConfigToken; #[test] fn test_create_post_processor() { let tokenizer_config = HubTokenizerConfig { add_bos_token: None, add_eos_token: None, - bos_token: Some("".to_string()), - eos_token: Some("".to_string()), + bos_token: Some(TokenizerConfigToken::String("".to_string())), + eos_token: Some(TokenizerConfigToken::String("".to_string())), chat_template: None, tokenizer_class: None, completion_template: None, @@ -629,9 +630,9 @@ mod tests { let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); let expected = TemplateProcessing::builder() - .try_single(":0 $A:0 :1") + .try_single(":0 $A:0") .unwrap() - .try_pair(":0 $A:0 $B:1") + .try_pair(":0 $A:0 :1 $B:1") .unwrap() .special_tokens(vec![("".to_string(), 1)]) .build() diff --git a/router/src/server.rs b/router/src/server.rs index 0cb08d4e..d24774f9 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,17 +12,18 @@ use crate::kserve::{ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, - HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, - Token, TokenizeResponse, Usage, Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, + Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, + Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, + CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, + VertexResponse, }; -use crate::{FunctionDefinition, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -635,7 +636,7 @@ async fn completions( )); } - if req.prompt.len() > info.max_client_batch_size { + if req.prompt.0.len() > info.max_client_batch_size { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -651,6 +652,7 @@ async fn completions( let generate_requests: Vec = req .prompt + .0 .iter() .map(|prompt| GenerateRequest { inputs: prompt.to_string(), @@ -705,7 +707,6 @@ async fn completions( event .json_data(CompletionCompleteChunk { id: "".to_string(), - object: "text_completion".to_string(), created: current_time, choices: vec![CompletionComplete { @@ -932,7 +933,6 @@ async fn completions( let response = Completion { id: "".to_string(), - object: "text_completion".to_string(), created: current_time, model: info.model_id.clone(), system_fingerprint: format!( @@ -1153,14 +1153,16 @@ async fn chat_completions( }; event - .json_data(ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - content, - tool_calls, - current_time, - logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + .json_data(CompletionType::ChatCompletionChunk( + ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + stream_token.details.map(|d| d.finish_reason.to_string()), + ), )) .unwrap_or_else(|e| { println!("Failed to serialize ChatCompletionChunk: {:?}", e); @@ -1228,7 +1230,7 @@ async fn chat_completions( (None, Some(generation.generated_text)) }; // build the complete response object with the full text - let response = ChatCompletion::new( + let response = CompletionType::ChatCompletion(ChatCompletion::new( model_id, system_fingerprint, output, @@ -1236,7 +1238,7 @@ async fn chat_completions( generation.details.unwrap(), logprobs, tool_calls, - ); + )); // wrap generation inside a Vec to match api-inference Ok((headers, Json(response)).into_response()) From 2b9339c65bbec286f4fbb165afcdd55025cc3238 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 23:25:54 +0200 Subject: [PATCH 093/340] Fixing baichuan override. (#2158) --- .../models/custom_modeling/flash_llama_modeling.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6b82aeca..0ea9f623 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -117,6 +117,11 @@ class FlashLlamaAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads + # Setting defaults for baichuan custom config which doesn't apply them. + config.rope_theta = getattr(config, "rope_theta", 10000) + config.num_key_value_heads = getattr( + config, "num_key_value_heads", config.num_attention_heads + ) self.rotary_emb = PositionRotaryEmbedding.static( config=config, dim=self.head_size, From b80bd724e1b6f9c637ec2f724965fae9cf8cc8ed Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 23:28:00 +0200 Subject: [PATCH 094/340] Move to FlashDecoding instead of PagedAttention kernel. (#1940) * Using flash decoding Conditional flashdecoding. Fix max_q. Working kvcache Working version with flash decoding. Make it work for mistral. Fix after rebase.. Less intrusive. REvert changes in modeling. Speedup flashdecoding. HHachweew Hack to make other models work. Fixing non flash decoding llama path. Router logic knows about page size. Missing 2 models. Missing cohere. Fixing cohere flash decoding. Revamped all this architecture. Fix cohere. Fixing falcon. Enabling custom block size schedule. Update router/src/infer.rs Not sending preallocated output. * Making it work on non flash decoding. * Fix Cohere. * Fix non decoding paths. * Rebased. * No need for cache_manager anymore. * Update? * "ipex" -> "cpu" * These do not belong. * Factoring cu_seqlen_qk for better abstracting over every model. * Fixing non flash tests/imports. * Changing return everywhere. * Update mistral past. * Fixing Mi{s,x}tral (non functional in Flash Decoding mode though). * Fixup mistral clamping (had issues with cuda graphs). * No need to recreate anything actually. --- router/src/infer/v2/scheduler.rs | 9 +- router/src/infer/v3/scheduler.rs | 8 +- .../layers/attention/__init__.py | 2 + .../layers/attention/common.py | 44 ++++++ .../layers/attention/cuda.py | 139 ++++++++++++------ .../layers/attention/ipex.py | 5 +- .../layers/attention/rocm.py | 15 +- .../text_generation_server/models/__init__.py | 3 +- .../custom_modeling/flash_cohere_modeling.py | 6 +- .../custom_modeling/flash_dbrx_modeling.py | 2 +- .../custom_modeling/flash_gemma2_modeling.py | 2 +- .../custom_modeling/flash_gemma_modeling.py | 2 +- .../custom_modeling/flash_gpt2_modeling.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 3 +- .../custom_modeling/flash_mistral_modeling.py | 5 +- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../custom_modeling/flash_neox_modeling.py | 2 +- .../custom_modeling/flash_phi_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 2 +- .../custom_modeling/flash_rw_modeling.py | 4 +- .../flash_santacoder_modeling.py | 2 +- .../flash_starcoder2_modeling.py | 2 +- .../models/flash_causal_lm.py | 25 +++- .../text_generation_server/models/globals.py | 8 +- 24 files changed, 223 insertions(+), 75 deletions(-) create mode 100644 server/text_generation_server/layers/attention/common.py diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index ba6f520d..e4c3de26 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -39,7 +39,14 @@ impl SchedulerV2 { speculate: u32, generation_health: Arc, ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); + // Infer shared state + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; + let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ad03dd83..543ce89f 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -39,9 +39,15 @@ impl SchedulerV3 { speculate: u32, generation_health: Arc, ) -> Self { + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, - 16, + block_size, window_size, speculate, max_batch_total_tokens, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index e74180e7..c8bccefe 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,6 +1,8 @@ from text_generation_server.utils.import_utils import SYSTEM import os +from .common import Seqlen + if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py new file mode 100644 index 00000000..bd0717ce --- /dev/null +++ b/server/text_generation_server/layers/attention/common.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from text_generation_server.models.globals import FLASH_DECODING +import torch +from typing import Optional + + +if FLASH_DECODING: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + + def __init__(self, input_lengths): + self.input_lengths = input_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + +else: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + + def clamp(self, max): + return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 583337bd..94b69899 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,5 +1,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.layers.attention import Seqlen major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 @@ -21,7 +23,14 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -32,7 +41,7 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -53,7 +62,8 @@ def paged_attention( # # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] + # block_size = value_cache.shape[3] + block_size = BLOCK_SIZE num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE @@ -62,58 +72,95 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - from vllm._C import ops + if FLASH_DECODING: + max_q = 1 + max_k = max_s + import flash_attn_2_cuda - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, + # TODO fixme when flash contains the fix. + # Number of splits is not correctly handled + # by the current path + # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 + # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. + out2 = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, None, - "auto", - 1.0, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + block_tables, + None, + max_q, + max_k, + 0.0, # dropout + softmax_scale, + False, # zero_tensors + True, # causal + -1, # Window_left + -1, # Window right + False, # return softmax + None, # generator ) + return out2[0] else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) + input_lengths = seqlen.input_lengths + from vllm._C import ops - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, + use_v1 = max_s <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 ) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + return out try: diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 7f086b68..db79c589 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -55,7 +55,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( @@ -66,7 +67,7 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - input_lengths, + cu_seqlen_q, BLOCK_SIZE, max_s, None, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 91ed5818..36db12d0 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,6 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -26,7 +27,14 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -37,7 +45,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -61,6 +70,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + input_lengths = cu_seqlen_k # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -119,6 +129,7 @@ def paged_attention( "auto", 1.0, ) + return out if ENGINE != "triton": diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f2f0f457..5ea43290 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -12,7 +12,6 @@ from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM @@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." FLASH_ATTENTION = True try: + from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.flash_neox import FlashNeoXSharded @@ -92,6 +92,7 @@ except ImportError as e: FLASH_ATTENTION = False if FLASH_ATTENTION: + __all__.append(FlashCausalLM) __all__.append(FlashGPT2) __all__.append(FlashNeoXSharded) __all__.append(FlashRWSharded) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 2850a6f3..e088f9aa 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, - slots, input_lengths, + slots, max_s, ): qkv = self.query_key_value(hidden_states) @@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module): ) residual = None + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9d56e4ef..aea7f399 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a71de61f..cfa6b2fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 82891823..842df0d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 7e7510c7..9f800146 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0ea9f623..77a7e2d5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d1ba5564..69ed5f64 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( + Seqlen, paged_attention, attention, reshape_and_cache, @@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2e839d15..2d6a7f97 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b87fd4ca..33aebc2b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 3f445f97..f237ea37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 69f38c3a..2e281386 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 04d4ba51..e7614232 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index badfc367..30989a37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index f6a2e15d..df864bc1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a0a78b33..49a088a1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -30,10 +30,13 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, + FLASH_DECODING, + BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, MODEL_ID, ) +from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments @@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) -BLOCK_SIZE: int = 16 # Will be set in init SLIDING_WINDOW: Optional[int] = None @@ -856,7 +858,23 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if SYSTEM == "ipex" and device == torch.device("cpu"): + if FLASH_DECODING: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + elif SYSTEM == "ipex" and device == torch.device("cpu"): self.kv_cache = [ ( torch.empty( @@ -908,6 +926,7 @@ class FlashCausalLM(Model): "slots": slots, "input_lengths": input_lengths, } + input_lengths = Seqlen(input_lengths=input_lengths) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1067,6 +1086,7 @@ class FlashCausalLM(Model): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + input_lengths = Seqlen(input_lengths=input_lengths) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1153,6 +1173,7 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index bde86e6e..06035ccd 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,12 @@ from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli +FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} +BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 +if FLASH_DECODING: + logger.info("Using FLASH_DECODING") + + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: @@ -15,8 +21,6 @@ if cuda_graphs is not None: ) else: cuda_graphs = None - - # sorting the cuda graphs in descending order helps reduce the # memory impact and results in less memory usage if cuda_graphs is not None: From 9b3d3a3690f0a6b6339f493656df13533be6cf93 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 11:43:07 +0200 Subject: [PATCH 095/340] Fixing graph capture for flash decoding. (#2163) --- server/text_generation_server/models/flash_causal_lm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 49a088a1..4f276ed4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -926,7 +926,7 @@ class FlashCausalLM(Model): "slots": slots, "input_lengths": input_lengths, } - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -939,7 +939,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -947,6 +947,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, From 71b0189cd51651fbbcdb46d70d17300c6e26d14a Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 2 Jul 2024 17:56:07 +0800 Subject: [PATCH 096/340] fix FlashDecoding change's regression in intel platform (#2161) install triton because GPTQParams needs it. Signed-off-by: Wang, Yi A --- Dockerfile_intel | 2 ++ server/text_generation_server/layers/attention/ipex.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index a41fbc1e..3c060f19 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl +RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed # Install server @@ -132,6 +133,7 @@ RUN conda install -c conda-forge gperftools mkl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install triton WORKDIR /usr/src diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index db79c589..45a0a03e 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,6 +1,7 @@ import intel_extension_for_pytorch as ipex import torch from text_generation_server.models.flash_causal_lm import BLOCK_SIZE +from text_generation_server.layers.attention import Seqlen SUPPORTS_WINDOWING = False @@ -55,11 +56,10 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - cu_seqlen_q: torch.Tensor, - cu_seqlen_k: torch.Tensor, + seqlen: Seqlen, max_s: int, ): - return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, key_cache, @@ -67,8 +67,9 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - cu_seqlen_q, + seqlen.input_lengths, BLOCK_SIZE, max_s, None, ) + return out From e913f3ad2d9253218608cef18ebe739b9c18a073 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 2 Jul 2024 05:56:25 -0400 Subject: [PATCH 097/340] fix: use the base layers weight in mistral rocm (#2155) --- .../models/custom_modeling/flash_mistral_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 69ed5f64..396969cd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -315,7 +315,9 @@ class MistralMLP(nn.Module): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + _custom_C.LLMM_Silu( + self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 + ) return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) From bc5a792dc88d8b709aa15b3e6b8abf79a6876b25 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 12:01:08 +0200 Subject: [PATCH 098/340] Fixing rocm. (#2164) --- server/text_generation_server/layers/attention/rocm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 36db12d0..99c490d5 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -2,6 +2,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.layers.attention import Seqlen from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -45,8 +46,7 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - cu_seqlen_q: torch.Tensor, - cu_seqlen_k: torch.Tensor, + input_lengths: Seqlen, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -70,7 +70,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = cu_seqlen_k + input_lengths = input_lengths.input_lengths # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use From d580215a2418b7f6fb20fee018ba335525235c49 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 14:26:47 +0200 Subject: [PATCH 099/340] Hotfixing qwen2 and starcoder2 (which also get clamping). (#2167) --- .../models/custom_modeling/flash_qwen2_modeling.py | 2 +- .../models/custom_modeling/flash_starcoder2_modeling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 2e281386..1cc6a613 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index df864bc1..a0273c37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, From 233e46409a635d13194447d940f25ffbd937378f Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 3 Jul 2024 03:53:35 -0400 Subject: [PATCH 100/340] feat: improve update_docs for openapi schema (#2169) * feat: add pre commit step to force schema update when router changes * fix: prefer improved update_doc and start server and compare * fix: adjust typo * fix: adjust revert typo * fix: update workflow to use update_doc md command * feat: improve workflow to check openapi schema too * fix: adjust timeout for CI * fix: adjust raise condition and install server in ci * fix: install protoc before server * feat: improve update doc and add command to print router schema * fix: adjust autodoc workflow * fix: explicitly install protoc and python * fix: alllow trailing space in openapi schema diff --- docs/openapi.json | 165 ++++++++++++++++++++----------------------- router/src/main.rs | 23 +++++- router/src/server.rs | 15 ++-- update_doc.py | 50 +++++++++++++ 4 files changed, 158 insertions(+), 95 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 79c3b80f..7dc159a8 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.0.1" + "version": "2.1.1-dev0" }, "paths": { "/": { @@ -19,7 +19,6 @@ "Text Generation Inference" ], "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", - "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`", "operationId": "compat_generate", "requestBody": { "content": { @@ -108,7 +107,6 @@ "Text Generation Inference" ], "summary": "Generate tokens", - "description": "Generate tokens", "operationId": "generate", "requestBody": { "content": { @@ -192,7 +190,6 @@ "Text Generation Inference" ], "summary": "Generate a stream of token using Server-Sent Events", - "description": "Generate a stream of token using Server-Sent Events", "operationId": "generate_stream", "requestBody": { "content": { @@ -276,7 +273,6 @@ "Text Generation Inference" ], "summary": "Health check method", - "description": "Health check method", "operationId": "health", "responses": { "200": { @@ -305,7 +301,6 @@ "Text Generation Inference" ], "summary": "Text Generation Inference endpoint info", - "description": "Text Generation Inference endpoint info", "operationId": "get_model_info", "responses": { "200": { @@ -327,7 +322,6 @@ "Text Generation Inference" ], "summary": "Prometheus metrics scrape endpoint", - "description": "Prometheus metrics scrape endpoint", "operationId": "metrics", "responses": { "200": { @@ -349,7 +343,6 @@ "Text Generation Inference" ], "summary": "Tokenize inputs", - "description": "Tokenize inputs", "operationId": "tokenize", "requestBody": { "content": { @@ -394,7 +387,6 @@ "Text Generation Inference" ], "summary": "Generate tokens", - "description": "Generate tokens", "operationId": "chat_completions", "requestBody": { "content": { @@ -483,7 +475,6 @@ "Text Generation Inference" ], "summary": "Generate tokens", - "description": "Generate tokens", "operationId": "completions", "requestBody": { "content": { @@ -626,7 +617,6 @@ "type": "object", "required": [ "id", - "object", "created", "model", "system_fingerprint", @@ -653,9 +643,6 @@ "type": "string", "example": "mistralai/Mistral-7B-Instruct-v0.2" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": "string" }, @@ -697,7 +684,6 @@ "type": "object", "required": [ "id", - "object", "created", "model", "system_fingerprint", @@ -723,9 +709,6 @@ "type": "string", "example": "mistralai/Mistral-7B-Instruct-v0.2" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": "string" } @@ -756,34 +739,19 @@ "nullable": true }, "message": { - "$ref": "#/components/schemas/Message" + "$ref": "#/components/schemas/OutputMessage" } } }, "ChatCompletionDelta": { - "type": "object", - "required": [ - "role" - ], - "properties": { - "content": { - "type": "string", - "example": "What is Deep Learning?", - "nullable": true + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" }, - "role": { - "type": "string", - "example": "user" - }, - "tool_calls": { - "allOf": [ - { - "$ref": "#/components/schemas/DeltaToolCall" - } - ], - "nullable": true + { + "$ref": "#/components/schemas/ToolCallDelta" } - } + ] }, "ChatCompletionLogprob": { "type": "object", @@ -903,6 +871,15 @@ "example": 0.1, "nullable": true }, + "response_format": { + "allOf": [ + { + "$ref": "#/components/schemas/GrammarType" + } + ], + "default": "null", + "nullable": true + }, "seed": { "type": "integer", "format": "int64", @@ -1021,7 +998,6 @@ "type": "object", "required": [ "id", - "object", "created", "choices", "model", @@ -1045,9 +1021,6 @@ "model": { "type": "string" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": "string" } @@ -1081,12 +1054,7 @@ "example": "mistralai/Mistral-7B-Instruct-v0.2" }, "prompt": { - "type": "array", - "items": { - "type": "string" - }, - "description": "The prompt to generate completions for.", - "example": "What is Deep Learning?" + "$ref": "#/components/schemas/Prompt" }, "repetition_penalty": { "type": "number", @@ -1100,6 +1068,15 @@ "nullable": true, "minimum": 0 }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Up to 4 sequences where the API will stop generating further tokens.", + "example": "null", + "nullable": true + }, "stream": { "type": "boolean" }, @@ -1121,15 +1098,6 @@ "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "example": 0.95, "nullable": true - }, - "stop": { - "type": "array", - "items": { - "type": "string" - }, - "description": "Up to 4 sequences where the API will stop generating further tokens.", - "example": "null", - "nullable": true } } }, @@ -1272,8 +1240,16 @@ "GenerateParameters": { "type": "object", "properties": { + "adapter_id": { + "type": "string", + "description": "Lora adapter id", + "default": "null", + "example": "null", + "nullable": true + }, "best_of": { "type": "integer", + "description": "Generate best_of sequences and return the one if the highest token logprobs.", "default": "null", "example": 1, "nullable": true, @@ -1282,20 +1258,24 @@ }, "decoder_input_details": { "type": "boolean", + "description": "Whether to return decoder input token logprobs and ids.", "default": "false" }, "details": { "type": "boolean", + "description": "Whether to return generation details.", "default": "true" }, "do_sample": { "type": "boolean", + "description": "Activate logits sampling.", "default": "false", "example": true }, "frequency_penalty": { "type": "number", "format": "float", + "description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", "default": "null", "example": 0.1, "nullable": true, @@ -1313,6 +1293,7 @@ "max_new_tokens": { "type": "integer", "format": "int32", + "description": "Maximum number of tokens to generate.", "default": "100", "example": "20", "nullable": true, @@ -1321,6 +1302,7 @@ "repetition_penalty": { "type": "number", "format": "float", + "description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.", "default": "null", "example": 1.03, "nullable": true, @@ -1328,6 +1310,7 @@ }, "return_full_text": { "type": "boolean", + "description": "Whether to prepend the prompt to the generated text", "default": "null", "example": false, "nullable": true @@ -1335,6 +1318,7 @@ "seed": { "type": "integer", "format": "int64", + "description": "Random sampling seed.", "default": "null", "example": "null", "nullable": true, @@ -1346,6 +1330,7 @@ "items": { "type": "string" }, + "description": "Stop generating tokens if a member of `stop` is generated.", "example": [ "photographer" ], @@ -1354,6 +1339,7 @@ "temperature": { "type": "number", "format": "float", + "description": "The value used to module the logits distribution.", "default": "null", "example": 0.5, "nullable": true, @@ -1362,6 +1348,7 @@ "top_k": { "type": "integer", "format": "int32", + "description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", "default": "null", "example": 10, "nullable": true, @@ -1370,6 +1357,7 @@ "top_n_tokens": { "type": "integer", "format": "int32", + "description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.", "default": "null", "example": 5, "nullable": true, @@ -1379,6 +1367,7 @@ "top_p": { "type": "number", "format": "float", + "description": "Top-p value for nucleus sampling.", "default": "null", "example": 0.95, "nullable": true, @@ -1387,6 +1376,7 @@ }, "truncate": { "type": "integer", + "description": "Truncate inputs tokens to the given size.", "default": "null", "example": "null", "nullable": true, @@ -1395,6 +1385,7 @@ "typical_p": { "type": "number", "format": "float", + "description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.", "default": "null", "example": 0.95, "nullable": true, @@ -1403,6 +1394,7 @@ }, "watermark": { "type": "boolean", + "description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).", "default": "false", "example": true } @@ -1495,13 +1487,14 @@ "max_concurrent_requests", "max_best_of", "max_stop_sequences", - "max_input_length", + "max_input_tokens", "max_total_tokens", "waiting_served_ratio", "max_batch_total_tokens", "max_waiting_tokens", "validation_workers", "max_client_batch_size", + "router", "version" ], "properties": { @@ -1538,7 +1531,7 @@ "example": "128", "minimum": 0 }, - "max_input_length": { + "max_input_tokens": { "type": "integer", "example": "1024", "minimum": 0 @@ -1581,6 +1574,11 @@ "example": "e985a63cdc139290c5f700ff1929f0b5942cced2", "nullable": true }, + "router": { + "type": "string", + "description": "Router Info", + "example": "text-generation-router" + }, "sha": { "type": "string", "example": "null", @@ -1593,7 +1591,6 @@ }, "version": { "type": "string", - "description": "Router Info", "example": "0.5.0" }, "waiting_served_ratio": { @@ -1606,13 +1603,12 @@ "Message": { "type": "object", "required": [ - "role" + "role", + "content" ], "properties": { "content": { - "type": "string", - "example": "My name is David and I", - "nullable": true + "$ref": "#/components/schemas/MessageContent" }, "name": { "type": "string", @@ -1622,13 +1618,6 @@ "role": { "type": "string", "example": "user" - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolCall" - }, - "nullable": true } } }, @@ -1817,9 +1806,7 @@ "$ref": "#/components/schemas/FunctionDefinition" }, "id": { - "type": "integer", - "format": "int32", - "minimum": 0 + "type": "string" }, "type": { "type": "string" @@ -1830,20 +1817,22 @@ "oneOf": [ { "type": "object", - "required": [ - "FunctionName" - ], - "properties": { - "FunctionName": { - "type": "string" - } - } + "default": null, + "nullable": true }, { - "type": "string", - "enum": [ - "OneOf" - ] + "type": "string" + }, + { + "type": "object", + "required": [ + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } } ] }, diff --git a/router/src/main.rs b/router/src/main.rs index 8618f57e..21cd6649 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,6 @@ use axum::http::HeaderValue; use clap::Parser; +use clap::Subcommand; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::{Cache, Repo, RepoType}; use opentelemetry::sdk::propagation::TraceContextPropagator; @@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + #[command(subcommand)] + command: Option, + #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, #[clap(default_value = "2", long, env)] @@ -85,10 +89,15 @@ struct Args { max_client_batch_size: usize, } +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + #[tokio::main] async fn main() -> Result<(), RouterError> { - // Get args let args = Args::parse(); + // Pattern match configuration let Args { max_concurrent_requests, @@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + command, } = args; - // Launch Tokio runtime - init_logging(otlp_endpoint, otlp_service_name, json_output); + let print_schema_command = match command { + Some(Commands::PrintSchema) => true, + None => { + // only init logging if we are not running the print schema command + init_logging(otlp_endpoint, otlp_service_name, json_output); + false + } + }; // Validate args if max_input_tokens >= max_total_tokens { @@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + print_schema_command, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index d24774f9..9be6a35c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1387,10 +1387,10 @@ async fn tokenize( /// Prometheus metrics scrape endpoint #[utoipa::path( -get, -tag = "Text Generation Inference", -path = "/metrics", -responses((status = 200, description = "Prometheus Metrics", body = String)) + get, + tag = "Text Generation Inference", + path = "/metrics", + responses((status = 200, description = "Prometheus Metrics", body = String)) )] async fn metrics(prom_handle: Extension) -> String { prom_handle.render() @@ -1430,6 +1430,7 @@ pub async fn run( messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, + print_schema_command: bool, ) -> Result<(), WebServerError> { // OpenAPI documentation #[derive(OpenApi)] @@ -1500,6 +1501,12 @@ pub async fn run( struct ApiDoc; // Create state + if print_schema_command { + let api_doc = ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + } // Open connection, get model info and warmup let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( diff --git a/update_doc.py b/update_doc.py index 5da81c72..1ff94a2c 100644 --- a/update_doc.py +++ b/update_doc.py @@ -1,6 +1,8 @@ import subprocess import argparse import ast +import json +import os TEMPLATE = """ # Supported Models and Hardware @@ -122,6 +124,53 @@ def check_supported_models(check: bool): f.write(final_doc) +def get_openapi_schema(): + try: + output = subprocess.check_output(["text-generation-router", "print-schema"]) + return json.loads(output) + except subprocess.CalledProcessError as e: + print(f"Error running text-generation-router print-schema: {e}") + raise SystemExit(1) + except json.JSONDecodeError: + print("Error: Invalid JSON received from text-generation-router print-schema") + raise SystemExit(1) + + +def check_openapi(check: bool): + new_openapi_data = get_openapi_schema() + filename = "docs/openapi.json" + tmp_filename = "openapi_tmp.json" + + with open(tmp_filename, "w") as f: + json.dump(new_openapi_data, f, indent=2) + + if check: + diff = subprocess.run( + [ + "diff", + # allow for trailing whitespace since it's not significant + # and the precommit hook will remove it + "--ignore-trailing-space", + tmp_filename, + filename, + ], + capture_output=True, + ).stdout.decode() + os.remove(tmp_filename) + + if diff: + print(diff) + raise Exception( + "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" + ) + + return True + else: + os.rename(tmp_filename, filename) + print("OpenAPI documentation updated.") + return True + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--check", action="store_true") @@ -130,6 +179,7 @@ def main(): check_cli(args.check) check_supported_models(args.check) + check_openapi(args.check) if __name__ == "__main__": From b6c8984658e04617f1fbdb3fbe392b3215a67804 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 10:40:22 +0000 Subject: [PATCH 101/340] Fixing missing `object` field for regular completions. --- docs/openapi.json | 95 ++++++++++++++++++++++++++++++++++++++++++-- router/src/lib.rs | 29 +++++++++----- router/src/server.rs | 14 ++++--- 3 files changed, 118 insertions(+), 20 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 7dc159a8..b2002efe 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -946,6 +946,38 @@ } } }, + "Chunk": { + "type": "object", + "required": [ + "id", + "created", + "choices", + "model", + "system_fingerprint" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CompletionComplete" + } + }, + "created": { + "type": "integer", + "format": "int64", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + } + } + }, "CompatGenerateRequest": { "type": "object", "required": [ @@ -965,6 +997,55 @@ } } }, + "Completion": { + "oneOf": [ + { + "allOf": [ + { + "$ref": "#/components/schemas/Chunk" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + }, + { + "allOf": [ + { + "$ref": "#/components/schemas/CompletionFinal" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + } + ], + "discriminator": { + "propertyName": "object" + } + }, "CompletionComplete": { "type": "object", "required": [ @@ -994,14 +1075,15 @@ } } }, - "CompletionCompleteChunk": { + "CompletionFinal": { "type": "object", "required": [ "id", "created", - "choices", "model", - "system_fingerprint" + "system_fingerprint", + "choices", + "usage" ], "properties": { "choices": { @@ -1013,16 +1095,21 @@ "created": { "type": "integer", "format": "int64", + "example": "1706270835", "minimum": 0 }, "id": { "type": "string" }, "model": { - "type": "string" + "type": "string", + "example": "mistralai/Mistral-7B-Instruct-v0.2" }, "system_fingerprint": { "type": "string" + }, + "usage": { + "$ref": "#/components/schemas/Usage" } } }, diff --git a/router/src/lib.rs b/router/src/lib.rs index 9ecfa051..165b2ad2 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -433,8 +433,17 @@ pub struct CompletionRequest { pub stop: Option>, } +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum Completion { + #[serde(rename = "text_completion")] + Chunk(Chunk), + #[serde(rename = "text_completion")] + Final(CompletionFinal), +} + #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] -pub(crate) struct Completion { +pub(crate) struct CompletionFinal { pub id: String, #[schema(example = "1706270835")] pub created: u64, @@ -453,6 +462,15 @@ pub(crate) struct CompletionComplete { pub finish_reason: String, } +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct Chunk { + pub id: String, + pub created: u64, + pub choices: Vec, + pub model: String, + pub system_fingerprint: String, +} + #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, @@ -614,15 +632,6 @@ impl ChatCompletion { } } } -#[derive(Clone, Deserialize, Serialize, ToSchema)] -pub(crate) struct CompletionCompleteChunk { - pub id: String, - pub created: u64, - pub choices: Vec, - pub model: String, - pub system_fingerprint: String, -} - #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, diff --git a/router/src/server.rs b/router/src/server.rs index 9be6a35c..ce0d6675 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -19,7 +19,7 @@ use crate::{ use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, - ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, + ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; @@ -705,7 +705,7 @@ async fn completions( .as_secs(); event - .json_data(CompletionCompleteChunk { + .json_data(Completion::Chunk(Chunk { id: "".to_string(), created: current_time, @@ -718,7 +718,7 @@ async fn completions( model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - }) + })) .unwrap_or_else(|_e| Event::default()) }; @@ -931,7 +931,7 @@ async fn completions( .collect::, _>>() .map_err(|(status, Json(err))| (status, Json(err)))?; - let response = Completion { + let response = Completion::Final(CompletionFinal { id: "".to_string(), created: current_time, model: info.model_id.clone(), @@ -946,7 +946,7 @@ async fn completions( completion_tokens, total_tokens, }, - }; + }); // headers similar to `generate` but aggregated let mut headers = HeaderMap::new(); @@ -1464,7 +1464,9 @@ pub async fn run( ChatCompletion, CompletionRequest, CompletionComplete, - CompletionCompleteChunk, + Chunk, + Completion, + CompletionFinal, GenerateParameters, PrefillToken, Token, From 878491cd5b6ff0afb8b2c1c5c56c568acc223f9e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 10:41:39 +0000 Subject: [PATCH 102/340] Revert "Fixing missing `object` field for regular completions." This reverts commit 2bbb7fa4b266cc1216e63f2047569fe1c8a0ad2a. --- docs/openapi.json | 95 ++------------------------------------------ router/src/lib.rs | 29 +++++--------- router/src/server.rs | 14 +++---- 3 files changed, 20 insertions(+), 118 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index b2002efe..7dc159a8 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -946,38 +946,6 @@ } } }, - "Chunk": { - "type": "object", - "required": [ - "id", - "created", - "choices", - "model", - "system_fingerprint" - ], - "properties": { - "choices": { - "type": "array", - "items": { - "$ref": "#/components/schemas/CompletionComplete" - } - }, - "created": { - "type": "integer", - "format": "int64", - "minimum": 0 - }, - "id": { - "type": "string" - }, - "model": { - "type": "string" - }, - "system_fingerprint": { - "type": "string" - } - } - }, "CompatGenerateRequest": { "type": "object", "required": [ @@ -997,55 +965,6 @@ } } }, - "Completion": { - "oneOf": [ - { - "allOf": [ - { - "$ref": "#/components/schemas/Chunk" - }, - { - "type": "object", - "required": [ - "object" - ], - "properties": { - "object": { - "type": "string", - "enum": [ - "text_completion" - ] - } - } - } - ] - }, - { - "allOf": [ - { - "$ref": "#/components/schemas/CompletionFinal" - }, - { - "type": "object", - "required": [ - "object" - ], - "properties": { - "object": { - "type": "string", - "enum": [ - "text_completion" - ] - } - } - } - ] - } - ], - "discriminator": { - "propertyName": "object" - } - }, "CompletionComplete": { "type": "object", "required": [ @@ -1075,15 +994,14 @@ } } }, - "CompletionFinal": { + "CompletionCompleteChunk": { "type": "object", "required": [ "id", "created", - "model", - "system_fingerprint", "choices", - "usage" + "model", + "system_fingerprint" ], "properties": { "choices": { @@ -1095,21 +1013,16 @@ "created": { "type": "integer", "format": "int64", - "example": "1706270835", "minimum": 0 }, "id": { "type": "string" }, "model": { - "type": "string", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "type": "string" }, "system_fingerprint": { "type": "string" - }, - "usage": { - "$ref": "#/components/schemas/Usage" } } }, diff --git a/router/src/lib.rs b/router/src/lib.rs index 165b2ad2..9ecfa051 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -433,17 +433,8 @@ pub struct CompletionRequest { pub stop: Option>, } -#[derive(Clone, Serialize, ToSchema)] -#[serde(tag = "object")] -enum Completion { - #[serde(rename = "text_completion")] - Chunk(Chunk), - #[serde(rename = "text_completion")] - Final(CompletionFinal), -} - #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] -pub(crate) struct CompletionFinal { +pub(crate) struct Completion { pub id: String, #[schema(example = "1706270835")] pub created: u64, @@ -462,15 +453,6 @@ pub(crate) struct CompletionComplete { pub finish_reason: String, } -#[derive(Clone, Deserialize, Serialize, ToSchema)] -pub(crate) struct Chunk { - pub id: String, - pub created: u64, - pub choices: Vec, - pub model: String, - pub system_fingerprint: String, -} - #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, @@ -632,6 +614,15 @@ impl ChatCompletion { } } } +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct CompletionCompleteChunk { + pub id: String, + pub created: u64, + pub choices: Vec, + pub model: String, + pub system_fingerprint: String, +} + #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, diff --git a/router/src/server.rs b/router/src/server.rs index ce0d6675..9be6a35c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -19,7 +19,7 @@ use crate::{ use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, - ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, + ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; @@ -705,7 +705,7 @@ async fn completions( .as_secs(); event - .json_data(Completion::Chunk(Chunk { + .json_data(CompletionCompleteChunk { id: "".to_string(), created: current_time, @@ -718,7 +718,7 @@ async fn completions( model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - })) + }) .unwrap_or_else(|_e| Event::default()) }; @@ -931,7 +931,7 @@ async fn completions( .collect::, _>>() .map_err(|(status, Json(err))| (status, Json(err)))?; - let response = Completion::Final(CompletionFinal { + let response = Completion { id: "".to_string(), created: current_time, model: info.model_id.clone(), @@ -946,7 +946,7 @@ async fn completions( completion_tokens, total_tokens, }, - }); + }; // headers similar to `generate` but aggregated let mut headers = HeaderMap::new(); @@ -1464,9 +1464,7 @@ pub async fn run( ChatCompletion, CompletionRequest, CompletionComplete, - Chunk, - Completion, - CompletionFinal, + CompletionCompleteChunk, GenerateParameters, PrefillToken, Token, From 64989f94390af5c2ab5bae1548ed1cb5a42c0423 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 12:48:45 +0200 Subject: [PATCH 103/340] Fixing the dockerfile warnings. (#2173) --- Dockerfile_amd | 16 ++++++++-------- Dockerfile_intel | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index e325bb65..0aebeee5 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -4,7 +4,7 @@ WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse -FROM chef as planner +FROM chef AS planner COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml @@ -37,7 +37,7 @@ COPY launcher launcher RUN cargo build --profile release-opt # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base +FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -115,7 +115,7 @@ ARG BUILD_CAFFE2="0" \ RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install -# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm +# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm ENV HIP_FORCE_DEV_KERNARG=1 # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. @@ -143,26 +143,26 @@ COPY server/Makefile-flash-att-v2 Makefile RUN make build-flash-attention-v2-rocm # Build Transformers CUDA kernels (gpt-neox and bloom) -FROM kernel-builder as custom-kernels-builder +FROM kernel-builder AS custom-kernels-builder WORKDIR /usr/src COPY server/custom_kernels/ . RUN python setup.py build # Build exllama kernels -FROM kernel-builder as exllama-kernels-builder +FROM kernel-builder AS exllama-kernels-builder WORKDIR /usr/src COPY server/exllama_kernels/ . RUN python setup.py build # Build exllama v2 kernels -FROM kernel-builder as exllamav2-kernels-builder +FROM kernel-builder AS exllamav2-kernels-builder WORKDIR /usr/src COPY server/exllamav2_kernels/ . RUN python setup.py build -FROM base as base-copy +FROM base AS base-copy # Text Generation Inference base env ENV HUGGINGFACE_HUB_CACHE=/data \ @@ -201,7 +201,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher # AWS Sagemaker compatible image -FROM base as sagemaker +FROM base AS sagemaker COPY sagemaker-entrypoint.sh entrypoint.sh RUN chmod +x entrypoint.sh diff --git a/Dockerfile_intel b/Dockerfile_intel index 3c060f19..6a803a32 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -5,7 +5,7 @@ WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse -FROM chef as planner +FROM chef AS planner COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml @@ -40,7 +40,7 @@ RUN cargo build --profile release-opt # Text Generation Inference base image for Intel -FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu +FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu USER root # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it @@ -95,7 +95,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo # Text Generation Inference base image for Intel-cpu -FROM ubuntu:22.04 as cpu +FROM ubuntu:22.04 AS cpu RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ curl \ @@ -172,6 +172,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca # Install launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher -FROM ${PLATFORM} as final +FROM ${PLATFORM} AS final ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] From e93c830e6637f9a694654df1e20e2f9facf078de Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 12:56:27 +0200 Subject: [PATCH 104/340] Fixing missing `object` field for regular completions. (#2175) * Fixing missing `object` field for regular completions. * Fixing docs by re-adding missing `Prompt`. --- docs/openapi.json | 101 +++++++++++++++++++++++++++++++++++++++++-- router/src/lib.rs | 29 ++++++++----- router/src/server.rs | 17 +++++--- 3 files changed, 126 insertions(+), 21 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 7dc159a8..5e0399e0 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -946,6 +946,38 @@ } } }, + "Chunk": { + "type": "object", + "required": [ + "id", + "created", + "choices", + "model", + "system_fingerprint" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CompletionComplete" + } + }, + "created": { + "type": "integer", + "format": "int64", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + } + } + }, "CompatGenerateRequest": { "type": "object", "required": [ @@ -965,6 +997,55 @@ } } }, + "Completion": { + "oneOf": [ + { + "allOf": [ + { + "$ref": "#/components/schemas/Chunk" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + }, + { + "allOf": [ + { + "$ref": "#/components/schemas/CompletionFinal" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + } + ], + "discriminator": { + "propertyName": "object" + } + }, "CompletionComplete": { "type": "object", "required": [ @@ -994,14 +1075,15 @@ } } }, - "CompletionCompleteChunk": { + "CompletionFinal": { "type": "object", "required": [ "id", "created", - "choices", "model", - "system_fingerprint" + "system_fingerprint", + "choices", + "usage" ], "properties": { "choices": { @@ -1013,16 +1095,21 @@ "created": { "type": "integer", "format": "int64", + "example": "1706270835", "minimum": 0 }, "id": { "type": "string" }, "model": { - "type": "string" + "type": "string", + "example": "mistralai/Mistral-7B-Instruct-v0.2" }, "system_fingerprint": { "type": "string" + }, + "usage": { + "$ref": "#/components/schemas/Usage" } } }, @@ -1647,6 +1734,12 @@ } } }, + "Prompt": { + "type": "array", + "items": { + "type": "string" + } + }, "SimpleToken": { "type": "object", "required": [ diff --git a/router/src/lib.rs b/router/src/lib.rs index 9ecfa051..165b2ad2 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -433,8 +433,17 @@ pub struct CompletionRequest { pub stop: Option>, } +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum Completion { + #[serde(rename = "text_completion")] + Chunk(Chunk), + #[serde(rename = "text_completion")] + Final(CompletionFinal), +} + #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] -pub(crate) struct Completion { +pub(crate) struct CompletionFinal { pub id: String, #[schema(example = "1706270835")] pub created: u64, @@ -453,6 +462,15 @@ pub(crate) struct CompletionComplete { pub finish_reason: String, } +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct Chunk { + pub id: String, + pub created: u64, + pub choices: Vec, + pub model: String, + pub system_fingerprint: String, +} + #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, @@ -614,15 +632,6 @@ impl ChatCompletion { } } } -#[derive(Clone, Deserialize, Serialize, ToSchema)] -pub(crate) struct CompletionCompleteChunk { - pub id: String, - pub created: u64, - pub choices: Vec, - pub model: String, - pub system_fingerprint: String, -} - #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, diff --git a/router/src/server.rs b/router/src/server.rs index 9be6a35c..db8b16ad 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -19,8 +19,8 @@ use crate::{ use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, - ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, + ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, + CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; @@ -705,7 +705,7 @@ async fn completions( .as_secs(); event - .json_data(CompletionCompleteChunk { + .json_data(Completion::Chunk(Chunk { id: "".to_string(), created: current_time, @@ -718,7 +718,7 @@ async fn completions( model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - }) + })) .unwrap_or_else(|_e| Event::default()) }; @@ -931,7 +931,7 @@ async fn completions( .collect::, _>>() .map_err(|(status, Json(err))| (status, Json(err)))?; - let response = Completion { + let response = Completion::Final(CompletionFinal { id: "".to_string(), created: current_time, model: info.model_id.clone(), @@ -946,7 +946,7 @@ async fn completions( completion_tokens, total_tokens, }, - }; + }); // headers similar to `generate` but aggregated let mut headers = HeaderMap::new(); @@ -1464,7 +1464,10 @@ pub async fn run( ChatCompletion, CompletionRequest, CompletionComplete, - CompletionCompleteChunk, + Chunk, + Completion, + CompletionFinal, + Prompt, GenerateParameters, PrefillToken, Token, From 74ddd1265a1799e995c4a7af27b0467fae73652b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 12:39:07 +0200 Subject: [PATCH 105/340] Version 2.1.1 --- docs/openapi.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/openapi.json b/docs/openapi.json index 5e0399e0..03e2d4ff 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.1.1-dev0" + "version": "2.1.1" }, "paths": { "/": { From 2e09ebecf6be424cebf3324f297946d8545af957 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 10:55:33 +0200 Subject: [PATCH 106/340] Preparing patch release. (#2186) --- docs/source/installation_amd.md | 2 +- docs/source/installation_nvidia.md | 2 +- docs/source/quicktour.md | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index fe925e2a..33d85732 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \ + ghcr.io/huggingface/text-generation-inference:2.1.1-rocm \ --model-id $model ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 11c41763..4de6cb19 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.0 \ + ghcr.io/huggingface/text-generation-inference:2.1.1 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 09e56df4..c546bc03 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.0 \ + ghcr.io/huggingface/text-generation-inference:2.1.1 \ --model-id $model ``` @@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. ```bash -docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help +docker run ghcr.io/huggingface/text-generation-inference:2.1.1 --help ``` From 835ad0a923e6f4b926f6fd4480ba3d14ba8cfe91 Mon Sep 17 00:00:00 2001 From: Aaron Mihalik Date: Fri, 5 Jul 2024 03:46:41 -0400 Subject: [PATCH 107/340] Adding "longrope" for Phi-3 (#2172) (#2179) Adding "longrope" for phi-3 --- server/text_generation_server/layers/rotary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index b14005e6..87a61e82 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module): beta_fast=32, beta_slow=1, ) - elif rope_scaling["type"] == "su": + elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor( rope_scaling["short_factor"], dtype=torch.float32, device=device ) From 1b434e801942f9eb833f65e7eaa6ae233016ac06 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Jul 2024 10:29:56 +0200 Subject: [PATCH 108/340] Refactor dead code - Removing all `flash_xxx.py` files. (#2166) * Refactor dead code. * First working step. * Remove a lot of duplicated code. * More dead code. * More cleanup. * Fix Santacoder test. * Fixing the simple tests. * Fixing sharding. * Fixes for VLM. * Fixing santacoder (num_kv_heads hardcoded). * Removing more dead code. * Fixing `config.n_head`. * Stopping earlier because of `` in idefics2. * Addresses comments. * Removing the dead code. * Fuse back mistral into FlashCausalLM. * Finish removal. * Fixing docs + causal_lm `batch_class`. * Fixing docs + causal.lm. * Add default to Gemma Causality. * Default value for gemma/gemma2. * Wrong default. --- docs/source/supported_models.md | 1 + .../test_flash_idefics2_two_images.json | 48 +-- integration-tests/models/test_idefics2.py | 2 +- server/tests/models/test_bloom.py | 8 +- server/tests/models/test_causal_lm.py | 2 +- server/tests/models/test_santacoder.py | 5 +- server/tests/models/test_seq2seq_lm.py | 2 +- .../text_generation_server/models/__init__.py | 393 +++++++++++------- server/text_generation_server/models/bloom.py | 73 ---- .../models/causal_lm.py | 106 ++++- .../custom_modeling/flash_gemma2_modeling.py | 2 +- .../custom_modeling/flash_gemma_modeling.py | 2 +- .../flash_santacoder_modeling.py | 3 +- .../models/custom_modeling/llava_next.py | 8 +- .../models/flash_causal_lm.py | 193 ++++++++- .../models/flash_cohere.py | 75 ---- .../models/flash_dbrx.py | 100 ----- .../models/flash_gemma.py | 83 ---- .../models/flash_gemma2.py | 83 ---- .../models/flash_gpt2.py | 82 ---- .../models/flash_llama.py | 171 -------- .../models/flash_mistral.py | 126 +----- .../models/flash_mixtral.py | 31 -- .../models/flash_neox.py | 82 ---- .../models/flash_phi.py | 111 ----- .../models/flash_qwen2.py | 93 ----- .../text_generation_server/models/flash_rw.py | 91 ---- .../models/flash_santacoder.py | 99 ----- .../models/flash_starcoder2.py | 84 ---- .../models/galactica.py | 80 ---- .../text_generation_server/models/gpt_neox.py | 89 ---- .../text_generation_server/models/idefics2.py | 51 --- .../models/llava_next.py | 46 -- server/text_generation_server/models/mpt.py | 105 ----- server/text_generation_server/models/opt.py | 86 ---- .../models/pali_gemma.py | 42 -- server/text_generation_server/models/phi.py | 69 --- server/text_generation_server/models/rw.py | 84 ---- .../models/santacoder.py | 77 ---- .../models/seq2seq_lm.py | 99 ++++- server/text_generation_server/models/t5.py | 115 ----- .../models/vlm_causal_lm.py | 36 +- 42 files changed, 688 insertions(+), 2450 deletions(-) delete mode 100644 server/text_generation_server/models/flash_cohere.py delete mode 100644 server/text_generation_server/models/flash_dbrx.py delete mode 100644 server/text_generation_server/models/flash_gemma.py delete mode 100644 server/text_generation_server/models/flash_gemma2.py delete mode 100644 server/text_generation_server/models/flash_gpt2.py delete mode 100644 server/text_generation_server/models/flash_llama.py delete mode 100644 server/text_generation_server/models/flash_mixtral.py delete mode 100644 server/text_generation_server/models/flash_neox.py delete mode 100644 server/text_generation_server/models/flash_phi.py delete mode 100644 server/text_generation_server/models/flash_qwen2.py delete mode 100644 server/text_generation_server/models/flash_rw.py delete mode 100644 server/text_generation_server/models/flash_santacoder.py delete mode 100644 server/text_generation_server/models/flash_starcoder2.py delete mode 100644 server/text_generation_server/models/gpt_neox.py delete mode 100644 server/text_generation_server/models/idefics2.py delete mode 100644 server/text_generation_server/models/llava_next.py delete mode 100644 server/text_generation_server/models/mpt.py delete mode 100644 server/text_generation_server/models/opt.py delete mode 100644 server/text_generation_server/models/phi.py delete mode 100644 server/text_generation_server/models/rw.py delete mode 100644 server/text_generation_server/models/santacoder.py delete mode 100644 server/text_generation_server/models/t5.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 1eeed39f..2bdd00de 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) +- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [Gemma2](https://huggingface.co/google/gemma2-9b) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json index bf2dc5a1..44ccea71 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -1,130 +1,124 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 19, "prefill": [], "seed": null, "tokens": [ { "id": 415, - "logprob": -0.039886475, + "logprob": -0.03665161, "special": false, "text": " The" }, { "id": 12072, - "logprob": -0.1430664, + "logprob": -0.13549805, "special": false, "text": " cow" }, { "id": 349, - "logprob": -0.056488037, + "logprob": -0.05819702, "special": false, "text": " is" }, { "id": 6328, - "logprob": -0.6855469, + "logprob": -0.6826172, "special": false, "text": " standing" }, { "id": 356, - "logprob": -0.1685791, + "logprob": -0.1607666, "special": false, "text": " on" }, { "id": 272, - "logprob": -0.50097656, + "logprob": -0.5073242, "special": false, "text": " the" }, { "id": 10305, - "logprob": -0.017303467, + "logprob": -0.016418457, "special": false, "text": " beach" }, { "id": 304, - "logprob": -1.3564453, + "logprob": -1.3916016, "special": false, "text": " and" }, { "id": 272, - "logprob": -0.017868042, + "logprob": -0.020217896, "special": false, "text": " the" }, { "id": 13088, - "logprob": -0.0027103424, + "logprob": -0.0028133392, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.003156662, + "logprob": -0.003145218, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.37304688, + "logprob": -0.37060547, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.034576416, + "logprob": -0.034851074, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.29418945, + "logprob": -0.2878418, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.042877197, + "logprob": -0.046051025, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.00028443336, + "logprob": -0.00028848648, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.023223877, + "logprob": -0.025772095, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.018157959, + "logprob": -0.018127441, "special": false, "text": "." }, { "id": 32002, - "logprob": -0.00018393993, + "logprob": -0.00019824505, "special": true, "text": "" - }, - { - "id": 2, - "logprob": -1.1920929e-07, - "special": true, - "text": "" } ], "top_tokens": null diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index 9aaf6d8a..c5f48da3 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot) response.generated_text == " The cow is standing on the beach and the chicken is sitting on a pile of money." ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 19 assert response == response_snapshot diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 32ee6686..08292920 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) @pytest.fixture(scope="session") @@ -16,7 +19,10 @@ def default_bloom(): revision = "main" filenames = weight_hub_files(model_id, revision, ".safetensors") download_weights(filenames, model_id, revision) - return BLOOMSharded(model_id) + return BLOOMSharded( + model_id, + model_class=BloomForCausalLM, + ) @pytest.fixture(scope="session") diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6e6463bc..c000ef26 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch @pytest.fixture(scope="session") def default_causal_lm(): - return CausalLM("gpt2") + return CausalLM.fallback("gpt2") @pytest.fixture(scope="session") diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index cb2622d9..d5c91bff 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -1,13 +1,12 @@ import pytest from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.models.santacoder import SantaCoder +from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM @pytest.fixture(scope="session") def default_santacoder(): - return SantaCoder("bigcode/santacoder") + return CausalLM.fallback(model_id="bigcode/santacoder") @pytest.fixture diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 943c3b08..02666042 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -20,7 +20,7 @@ def mt0_small_tokenizer(): @pytest.fixture(scope="session") def default_seq2seq_lm(): - return Seq2SeqLM("bigscience/mt0-small") + return Seq2SeqLM.fallback("bigscience/mt0-small") @pytest.fixture diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5ea43290..15e74622 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -11,17 +11,26 @@ from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.bloom import BLOOMSharded -from text_generation_server.models.mpt import MPTSharded +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM +from text_generation_server.models.custom_modeling.mpt_modeling import ( + MPTForCausalLM, +) +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models.seq2seq_lm import Seq2SeqLM -from text_generation_server.models.rw import RW -from text_generation_server.models.opt import OPTSharded -from text_generation_server.models.galactica import GalacticaSharded -from text_generation_server.models.santacoder import SantaCoder -from text_generation_server.models.t5 import T5Sharded -from text_generation_server.models.gpt_neox import GPTNeoxSharded -from text_generation_server.models.phi import Phi +from text_generation_server.models.galactica import GalacticaCausalLMBatch +from text_generation_server.models.custom_modeling.neox_modeling import ( + GPTNeoxForCausalLM, +) +from text_generation_server.models.custom_modeling.phi_modeling import ( + PhiConfig, + PhiForCausalLM, +) +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) from text_generation_server.utils.import_utils import SYSTEM @@ -41,9 +50,6 @@ __all__ = [ "CausalLM", "GalacticaSharded", "Seq2SeqLM", - "SantaCoder", - "OPTSharded", - "T5Sharded", "get_model", ] @@ -53,38 +59,65 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_gpt2 import FlashGPT2 - from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, ) - from text_generation_server.models.flash_qwen2 import ( - FlashQwen2, + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, ) - from text_generation_server.models.flash_cohere import ( - FlashCohere, + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, ) - from text_generation_server.models.flash_gemma import ( - FlashGemma, + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, ) - from text_generation_server.models.flash_gemma2 import ( - FlashGemma2, + from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, + ) + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, ) from text_generation_server.models.pali_gemma import ( - PaliGemma, + PaliGemmaBatch, ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, ) from text_generation_server.models.idefics import IDEFICSSharded - from text_generation_server.models.llava_next import LlavaNext - from text_generation_server.models.idefics2 import Idefics2 - from text_generation_server.models.flash_mistral import FlashMistral - from text_generation_server.models.flash_mixtral import FlashMixtral - from text_generation_server.models.flash_phi import FlashPhi - from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 - from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") @@ -93,21 +126,7 @@ except ImportError as e: if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(FlashGPT2) - __all__.append(FlashNeoXSharded) - __all__.append(FlashRWSharded) - __all__.append(FlashSantacoderSharded) - __all__.append(FlashLlama) __all__.append(IDEFICSSharded) - __all__.append(FlashMistral) - __all__.append(FlashMixtral) - __all__.append(FlashDbrx) - __all__.append(FlashPhi) - __all__.append(FlashQwen2) - __all__.append(FlashStarcoder2) - __all__.append(FlashGemma) - __all__.append(FlashGemma2) - __all__.append(FlashCohere) MAMBA_AVAILABLE = True try: @@ -148,6 +167,11 @@ class ModelType(enum.Enum): "name": "Gemma", "url": "https://huggingface.co/google/gemma-7b", } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } GEMMA2 = { "type": "gemma2", "name": "Gemma2", @@ -445,13 +469,16 @@ def get_model( ) if model_id.startswith("facebook/galactica"): - return GalacticaSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + # Yes galactica is just an OPT model. + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=GalacticaCausalLMBatch, ) if ( @@ -460,22 +487,26 @@ def get_model( and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, ) elif sharded: raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") ) else: - return SantaCoder( - model_id, - revision, + return CausalLM.fallback( + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -483,38 +514,44 @@ def get_model( ) if model_type == BLOOM: - return BLOOMSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=BloomForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == MPT: - return MPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=MPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == GPT2: if FLASH_ATTENTION: try: - return FlashGPT2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) except RuntimeError as e: # Lots of legacy models with various weight names. logger.warning(f"Couldn't load flash gpt2 variant: {e}") - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -525,7 +562,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -535,25 +572,28 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: - return FlashNeoXSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: - return GPTNeoxSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=GPTNeoxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -564,16 +604,18 @@ def get_model( elif model_type == PHI: if FLASH_ATTENTION: - return FlashPhi( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -588,9 +630,11 @@ def get_model( "Legacy phi-msft is not supported with Flash Attention" ) else: - return Phi( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=PhiForCausalLM, + config_class=PhiConfig, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -599,9 +643,10 @@ def get_model( elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: if FLASH_ATTENTION: - return FlashLlama( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -611,7 +656,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -621,18 +666,22 @@ def get_model( ) if model_type == GEMMA: if FLASH_ATTENTION: - return FlashGemma( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -642,18 +691,22 @@ def get_model( ) elif model_type == GEMMA2: if FLASH_ATTENTION: - return FlashGemma2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -664,18 +717,20 @@ def get_model( if model_type == COHERE: if FLASH_ATTENTION: - return FlashCohere( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -686,18 +741,23 @@ def get_model( if model_type == DBRX: if FLASH_ATTENTION: - return FlashDbrx( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Dbrx works better in bfloat16. + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -711,27 +771,37 @@ def get_model( if FLASH_ATTENTION: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this model") - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) else: - return RW( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -742,18 +812,20 @@ def get_model( if model_type == MISTRAL: if FLASH_ATTENTION: - return FlashMistral( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -764,18 +836,20 @@ def get_model( if model_type == MIXTRAL: if FLASH_ATTENTION: - return FlashMixtral( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -786,19 +860,22 @@ def get_model( if model_type == STARCODER2: if FLASH_ATTENTION: - return FlashStarcoder2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -809,17 +886,20 @@ def get_model( if model_type == QWEN2: if FLASH_ATTENTION: - return FlashQwen2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -829,9 +909,10 @@ def get_model( ) if model_type == OPT: - return OPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -839,13 +920,20 @@ def get_model( ) if model_type == T5: - return T5Sharded( - model_id, - revision, + return Seq2SeqLM( + model_id=model_id, + model_class=T5ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) if model_type == IDEFICS: if FLASH_ATTENTION: @@ -861,34 +949,45 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == IDEFICS2: if FLASH_ATTENTION: - return Idefics2( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == "paligemma": + if model_type == PALIGEMMA: if FLASH_ATTENTION: - return PaliGemma( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: - return LlavaNext( - model_id, - revision, + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -912,7 +1011,7 @@ def get_model( elif quantize == "exl2": raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -921,7 +1020,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, @@ -933,7 +1032,7 @@ def get_model( auto_map = config_dict.get("auto_map", None) if trust_remote_code and auto_map is not None: if "AutoModelForCausalLM" in auto_map.keys(): - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -942,7 +1041,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if "AutoModelForSeq2SeqLM" in auto_map.keys(): - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 17aa12e8..732b4c53 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -4,22 +4,12 @@ import torch.distributed from typing import Optional, Type from transformers import ( - AutoTokenizer, - AutoConfig, PreTrainedTokenizerBase, ) -from text_generation_server.models.custom_modeling.bloom_modeling import ( - BloomForCausalLM, -) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) class BloomCausalLMBatch(CausalLMBatch): @@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch): class BLOOMSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - slow_but_exact=False, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.pad_token_id = 3 - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - prefix="transformer", - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = BloomForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 10c64c66..cac36ebd 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,13 +1,25 @@ import torch import time +import torch.distributed from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase +from transformers import ( + AutoConfig, + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -478,10 +490,87 @@ class CausalLMBatch(Batch): return len(self.requests) +@dataclass +class CausalLMBatchKeysLast(Batch): + keys_head_dim_last: bool = False + + class CausalLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + config_class=AutoConfig, + batch_class=CausalLMBatch, + ): + self.batch_class = batch_class + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = config.pad_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -537,7 +626,12 @@ class CausalLM(Model): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - super(CausalLM, self).__init__( + self = cls.__new__( + cls, + ) + self.batch_class = CausalLMBatch + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -545,15 +639,11 @@ class CausalLM(Model): dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + return self.batch_class def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index cfa6b2fe..625baa91 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 842df0d4..b7ce6307 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 30989a37..2bc305fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -464,8 +464,9 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() + config.transpose = config.architectures[0].startswith("GPT2") self.transformer = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 6d38442c..567131ef 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module): self.config = config config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator - self.language_model = load_text_model( + self.text_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", config=config.text_config, weights=weights, @@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module): image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, ): - inputs_embeds = self.language_model.embed_tokens(input_ids) + inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" @@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module): input_ids, inputs_embeds, image_features ) - hidden_states = self.language_model.model( + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.language_model.lm_head(hidden_states) + logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4f276ed4..c7f5f1f9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -10,7 +10,12 @@ import numpy as np from loguru import logger from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizerBase, + AutoConfig, + AutoTokenizer, + GenerationConfig, +) from typing import Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata @@ -21,6 +26,12 @@ from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, + hub, +) from text_generation_server.models.types import ( Batch, Tokens, @@ -799,29 +810,110 @@ class FlashCausalLMBatch(Batch): return len(self.requests) +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + class FlashCausalLM(Model): def __init__( self, model_id: str, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizerBase, - num_layers: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - rank: int = 0, - world_size: int = 1, - sliding_window: Optional[int] = None, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], + tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, + default_dtype=torch.float16, + aliases=None, + # Used for Santacoder override of config + num_kv_heads=None, + skip_special_tokens: bool = True, ): - self.num_layers = num_layers - self.num_kv_heads = num_kv_heads - self.head_size = head_size + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError(f"{model_class} is only available on GPU") + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + if isinstance(generation_config.eos_token_id, (list, set)): + # TODO Huge hack + tokenizer._eos_token_ids = set(generation_config.eos_token_id) + except Exception: + pass + + config = config_class.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device, dtype, process_group=self.process_group, aliases=aliases + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + prefix = "" + model = model_class(prefix, config, weights) + torch.distributed.barrier(group=self.process_group) + + # VLM models define the config we care about in their text_config + text_config = getattr(config, "text_config", None) + if text_config is not None: + config = text_config + self.num_layers = config.num_hidden_layers + # Validation is done in the model itself + if num_kv_heads is None: + num_kv_heads = getattr(config, "num_key_value_heads", None) + if num_kv_heads is None: + # Final overide for GPT2 + num_kv_heads = config.n_head + self.num_kv_heads = num_kv_heads // self.process_group.size() + self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} self.kv_cache = [] - super(FlashCausalLM, self).__init__( + super().__init__( model_id=model_id, model=model, tokenizer=tokenizer, @@ -830,7 +922,7 @@ class FlashCausalLM(Model): device=device, rank=rank, world_size=world_size, - sliding_window=sliding_window, + sliding_window=config.sliding_window, ) @property @@ -1578,3 +1670,72 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. + if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + # TODO: this is a hack to avoid the gate_proj for + # FlashStarcoder2 that doesnt have these layers + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py deleted file mode 100644 index 9f8bcb3f..00000000 --- a/server/text_generation_server/models/flash_cohere.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer, AutoConfig - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( - FlashCohereForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashCohere(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashCohere is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashCohereForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashCohere, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py deleted file mode 100644 index 2aba6a00..00000000 --- a/server/text_generation_server/models/flash_dbrx.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( - FlashDbrxForCausalLM, - DbrxConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashDbrx(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashDBRX is only available on GPU") - - try: - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - # FIXME: change back to model id once the tokenizer.json is merged - tokenizer = GPT2TokenizerFast.from_pretrained( - "Xenova/dbrx-instruct-tokenizer", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = DbrxConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashDbrxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashDbrx, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py deleted file mode 100644 index 7e2b8780..00000000 --- a/server/text_generation_server/models/flash_gemma.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - FlashGemmaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py deleted file mode 100644 index 86cfc7e2..00000000 --- a/server/text_generation_server/models/flash_gemma2.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import PretrainedConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( - FlashGemma2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = PretrainedConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py deleted file mode 100644 index 323fcafa..00000000 --- a/server/text_generation_server/models/flash_gpt2.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from transformers.models.gpt2 import GPT2Tokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( - FlashGPT2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGPT2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGPT2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashGPT2ForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashGPT2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py deleted file mode 100644 index d996b9c3..00000000 --- a/server/text_generation_server/models/flash_llama.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - FlashLlamaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, - hub, -) - -tracer = trace.get_tracer(__name__) - -from text_generation_server.utils.import_utils import SYSTEM - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashLlama(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - lora_adapter_ids: Optional[list] = [], - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashLlama is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - try: - generation_config = GenerationConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - if isinstance(generation_config.eos_token_id, (list, set)): - # TODO Huge hack - tokenizer._eos_token_ids = set(generation_config.eos_token_id) - except Exception: - pass - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashLlamaForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashLlama, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0f5746de..2b2bd2e0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,24 +1,7 @@ import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import set_sliding_window -from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( - FlashMistralForCausalLM, - MistralConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) ADAPTER_LAYERS = [ @@ -33,88 +16,7 @@ ADAPTER_LAYERS = [ ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} -class BaseFlashMistral(FlashCausalLM): - def __init__( - self, - model_cls, - model_id: str, - config_cls=AutoConfig, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashMistral is only available on GPU") - - tokenizer = tokenizer_class.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = config_cls.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = model_cls(prefix, config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - num_layers, num_kv_heads, head_size = self.get_layer_config(model) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.model.layers), - model.model.num_key_value_heads, - model.model.head_size, - ) - +class FlashMistral(FlashCausalLM): @property def supports_adapter_loading(self) -> bool: return True @@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM): # This accounts for VLMs (e.g. LlavaNext, Idefics2) # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): + if hasattr(self.model, "text_model"): _model = self.model.text_model else: _model = self.model @@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM): def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL - - -class FlashMistral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMistral, self).__init__( - config_cls=MistralConfig, - model_cls=FlashMistralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py deleted file mode 100644 index 587d423f..00000000 --- a/server/text_generation_server/models/flash_mixtral.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from typing import Optional - -from text_generation_server.models.flash_mistral import BaseFlashMistral -from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( - MixtralConfig, - FlashMixtralForCausalLM, -) - - -class FlashMixtral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMixtral, self).__init__( - config_cls=MixtralConfig, - model_cls=FlashMixtralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py deleted file mode 100644 index ac1fd573..00000000 --- a/server/text_generation_server/models/flash_neox.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashNeoXSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashNeoX is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashGPTNeoXForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashNeoXSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.gpt_neox.layers), - num_kv_heads=model.gpt_neox.num_heads, - head_size=model.gpt_neox.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py deleted file mode 100644 index a530d1c3..00000000 --- a/server/text_generation_server/models/flash_phi.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_phi_modeling import ( - FlashPhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashPhi(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashPhi is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashPhiForCausalLM(config, weights) - if speculator: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(speculator).exists() and Path(speculator).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - speculator, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - speculator, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(speculator) / "config.json") - medusa_head = str(Path(speculator) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - - torch.distributed.barrier(group=self.process_group) - super(FlashPhi, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py deleted file mode 100644 index cd6078f1..00000000 --- a/server/text_generation_server/models/flash_qwen2.py +++ /dev/null @@ -1,93 +0,0 @@ -import math - -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, -) -from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( - Qwen2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashQwen2(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashQwen2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = Qwen2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py deleted file mode 100644 index b1f75adc..00000000 --- a/server/text_generation_server/models/flash_rw.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_rw_modeling import ( - RWConfig, - FlashRWForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashRWSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashRW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - ) - - config.quantize = quantize - config.speculator = speculator - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashRWForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashRWSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=model.transformer.cache_size, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py deleted file mode 100644 index e1a7b36e..00000000 --- a/server/text_generation_server/models/flash_santacoder.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List -import json -import os - -from huggingface_hub import hf_hub_download -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashSantacoderSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=True, - ) - config.quantize = quantize - config.speculator = speculator - config.transpose = config.architectures[0].startswith("GPT2") - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={"transformer.wte.weight": ["lm_head.weight"]}, - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashSantacoderForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashSantacoderSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=1, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py deleted file mode 100644 index 369e9e4c..00000000 --- a/server/text_generation_server/models/flash_starcoder2.py +++ /dev/null @@ -1,84 +0,0 @@ -import math - -import torch - -from typing import Optional - -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, -) -from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( - Starcoder2Config, - FlashStarcoder2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -# Starcoder2 has the same base as Mistral -class FlashStarcoder2(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashStarcoder2 is only available on GPU") - - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = Starcoder2Config.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashStarcoder2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 30c92d90..2d43244a 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch): padding_right_offset=padding_right_offset, max_tokens=max_tokens, ) - - -class GalacticaSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - tokenizer.pad_token_id = config.pad_token_id - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return GalacticaCausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py deleted file mode 100644 index c37cfb7d..00000000 --- a/server/text_generation_server/models/gpt_neox.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.neox_modeling import ( - GPTNeoxForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class GPTNeoxSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = GPTNeoxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py deleted file mode 100644 index 314c0500..00000000 --- a/server/text_generation_server/models/idefics2.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.idefics2 import ( - Idefics2ForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class Idefics2(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - # XXX: Extremely important to cap resolution in order to limit - # VRAM usage. - size={"longest_edge": 448, "shortest_edge": 378}, - ) - super().__init__( - model_cls=Idefics2ForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py deleted file mode 100644 index effe8b91..00000000 --- a/server/text_generation_server/models/llava_next.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class LlavaNext(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - super().__init__( - model_cls=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.language_model.model.layers), - model.language_model.model.num_key_value_heads, - model.language_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py deleted file mode 100644 index 1e79b25f..00000000 --- a/server/text_generation_server/models/mpt.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.distributed - -from pathlib import Path -from typing import Optional, Type -from opentelemetry import trace -from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase -from huggingface_hub import hf_hub_download -import json - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.custom_modeling.mpt_modeling import ( - MPTForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class MPTCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) - batch.keys_head_dim_last = False - return batch - - -class MPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - # If model_id is a local path, load the file directly - local_path = Path(model_id, "config.json") - if local_path.exists(): - filename = str(local_path.resolve()) - else: - filename = hf_hub_download( - model_id, revision=revision, filename="config.json" - ) - with open(filename, "r") as f: - config = json.load(f) - config = PretrainedConfig(**config) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - config.quantize = quantize - model = MPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return MPTCausalLMBatch diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py deleted file mode 100644 index 6d7d07f5..00000000 --- a/server/text_generation_server/models/opt.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM -from text_generation_server.models import CausalLM -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class OPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - tokenizer.pad_token_id = config.pad_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index a167e467..3994ac70 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch): else: image_inputs = None return batch_tokenized_inputs, image_inputs - - -class PaliGemma(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - - super().__init__( - config_cls=AutoConfig, - model_cls=PaliGemmaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self): - return PaliGemmaBatch - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py deleted file mode 100644 index 93d42b2b..00000000 --- a/server/text_generation_server/models/phi.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.distributed - -from transformers import AutoConfig, AutoTokenizer -from typing import Optional, List, Tuple - -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.phi_modeling import ( - PhiConfig, - PhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class Phi(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, _rank, _world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = PhiConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - model = PhiForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py deleted file mode 100644 index 37ca277b..00000000 --- a/server/text_generation_server/models/rw.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch - -from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import List, Optional, Tuple - -from text_generation_server.models import CausalLM - - -class RW(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if speculator: - raise RuntimeError("Medusa decoding is not enabled for AutoModel") - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py deleted file mode 100644 index caddbe19..00000000 --- a/server/text_generation_server/models/santacoder.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional, List -from transformers import AutoTokenizer, AutoModelForCausalLM - -from text_generation_server.models import CausalLM - -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_SUFFIX = "" -FIM_PAD = "" -EOD = "<|endoftext|>" - - -class SantaCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.add_special_tokens( - { - "additional_special_tokens": [ - EOD, - FIM_PREFIX, - FIM_MIDDLE, - FIM_SUFFIX, - FIM_PAD, - ], - "pad_token": EOD, - } - ) - with device: - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index d454d804..dbaf1253 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,11 +1,22 @@ import torch +import torch.distributed import time from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoModelForSeq2SeqLM, + PreTrainedTokenizerBase, + AutoConfig, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model @@ -531,6 +542,80 @@ class Seq2SeqLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + config_class=AutoConfig, + tokenizer_class=AutoTokenizer, + aliases=None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.bos_token_id = config.decoder_start_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases=aliases, + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -574,7 +659,11 @@ class Seq2SeqLM(Model): ) tokenizer.bos_token_id = model.config.decoder_start_token_id - super(Seq2SeqLM, self).__init__( + self = cls.__new__( + cls, + ) + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -582,16 +671,12 @@ class Seq2SeqLM(Model): dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[Seq2SeqLMBatch]: return Seq2SeqLMBatch - def decode(self, decoder_ids: List[int]) -> str: - return self.tokenizer.decode( - decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - def forward( self, input_ids, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py deleted file mode 100644 index adef664c..00000000 --- a/server/text_generation_server/models/t5.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.distributed - -from typing import List, Optional, Tuple - -from transformers import ( - AutoTokenizer, - AutoConfig, -) - -from text_generation_server.models import Seq2SeqLM -from text_generation_server.models.custom_modeling.t5_modeling import ( - T5ForConditionalGeneration, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class T5Sharded(Seq2SeqLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.bos_token_id = config.decoder_start_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={ - "shared.weight": [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - }, - ) - - model = T5ForConditionalGeneration(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(Seq2SeqLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask: Optional, - encoder_last_hidden_state: Optional, - past_key_values: Optional = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - ]: - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_last_hidden_state, - past_key_values=past_key_values, - use_cache=True, - ) - - return ( - outputs.logits, - speculative_logits, - outputs.encoder_last_hidden_state, - outputs.past_key_values, - ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1cdf37ea..ace48805 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, ) +from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch): return batch -class VlmCausalLM(BaseFlashMistral): +class VlmCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=VlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__(model_id=model_id, **kwargs) + @property def batch_type(self) -> Type[VlmCausalLMBatch]: - return VlmCausalLMBatch + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) def forward( self, From e481a9bb9b8af3f599db846da7a79facb8f0fd1a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Jul 2024 09:25:29 +0000 Subject: [PATCH 109/340] Hotfixing after refactor. --- .../custom_modeling/flash_santacoder_modeling.py | 16 ++++++++-------- server/text_generation_server/models/model.py | 4 +++- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2bc305fe..daef43cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -355,7 +355,7 @@ class Block(nn.Module): self.ln_2 = FastLayerNorm.load( prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon ) - self.attn = FlashMQAttention( + self.self_attn = FlashMQAttention( prefix=f"{prefix}.attn", config=config, weights=weights, @@ -378,7 +378,7 @@ class Block(nn.Module): max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) - hidden_states = self.attn( + hidden_states = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, @@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module): reduce=False, ) - self.h = nn.ModuleList( + self.layers = nn.ModuleList( [ Block( layer_id, @@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module): prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon ) - self.head_size = self.h[0].attn.head_size - self.num_heads = self.h[0].attn.num_heads + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads def forward( self, @@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module): torch.distributed.all_reduce(hidden_states, group=self.process_group) residual = None - for i, layer in enumerate(self.h): + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, residual, @@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() config.transpose = config.architectures[0].startswith("GPT2") - self.transformer = FlashSantacoderModel(config, weights) + self.model = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) @@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer( + hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index c90fd38a..09130b85 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -60,7 +60,7 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = self.adapter_target_to_layer() + self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -187,6 +187,8 @@ class Model(ABC): into model. Otherwise, the adapter weights are applied during the forward pass and stored separately from the base model parameters. """ + if self.target_to_layer is None: + self.target_to_layer = self.adapter_target_to_layer() if adapter_index in self.loaded_adapters: # Adapter already loaded return From 1e7ce69f207bb899e4a08fc89ceedb59896e70d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Jul 2024 12:22:45 +0200 Subject: [PATCH 110/340] Fix Starcoder2 after refactor (#2189) --- .../flash_starcoder2_modeling.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index a0273c37..2b346283 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module): class Starcoder2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ @@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module): ] ) self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( - prefix="model.norm", weights=weights, eps=config.norm_epsilon + prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon ) self.gradient_checkpointing = False @@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module): class FlashStarcoder2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - self.model = Starcoder2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Starcoder2Model(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", weights=weights, ) From 54c194dfa6a4e5d8d35f87c30b6e757962e3d073 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Jul 2024 14:12:16 +0200 Subject: [PATCH 111/340] GPTQ CI improvements (#2151) * Add more representative Llama GPTQ test The Llama GPTQ test is updated to use a model with the commonly-used quantizer config format and activation sorting. The old test is kept around (but renamed) since it tests the format produced by `text-generation-server quantize`. * Add support for manually triggering a release build --- .github/workflows/ci_build.yaml | 11 +- .../test_flash_llama_gptq.json | 89 ++--- .../test_flash_llama_gptq_all_params.json | 85 ++--- .../test_flash_llama_gptq_load.json | 356 ++++++++--------- .../test_server_gptq_quantized.json | 89 +++++ ...test_server_gptq_quantized_all_params.json | 89 +++++ .../test_server_gptq_quantized_load.json | 358 ++++++++++++++++++ .../models/test_flash_llama_gptq.py | 4 +- 8 files changed, 799 insertions(+), 282 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json create mode 100644 integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 754c4850..d62297e4 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -20,7 +20,14 @@ on: - "Dockerfile_amd" - "Dockerfile_intel" branches: - - 'main' + - "main" + workflow_dispatch: + inputs: + release-tests: + description: "Run release integration tests" + required: true + default: false + type: boolean jobs: build: @@ -33,4 +40,6 @@ jobs: uses: ./.github/workflows/build.yaml # calls the one above ^ with: hardware: ${{ matrix.hardware }} + # https://github.com/actions/runner/issues/2206 + release-tests: ${{ inputs.release-tests == true }} secrets: inherit diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json index 7797cc6c..0f99d259 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -5,85 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.7890625, "text": "Test" }, { - "id": 2009, - "logprob": -9.625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3359375, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8779297, + "id": 262, + "logprob": -1.6230469, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2744141, + "id": 3270, + "logprob": -2.046875, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1425781, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.9238281, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6933594, + "id": 13204, + "logprob": -0.076660156, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4648438, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15600586, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.8027344, + "id": 3019, + "logprob": -0.10821533, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.23022461, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.0069885254, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.02218628, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json index fa2fd4a2..4152b5b3 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -5,85 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.6015625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": 0, "tokens": [ { - "id": 29899, - "logprob": -1.5625, + "id": 13, + "logprob": -2.2539062, "special": false, - "text": "-" + "text": "." }, { - "id": 1454, - "logprob": -0.20410156, + "id": 578, + "logprob": -0.15563965, "special": false, - "text": "for" + "text": " The" }, { - "id": 29899, + "id": 3622, + "logprob": -0.8203125, + "special": false, + "text": " server" + }, + { + "id": 706, "logprob": 0.0, "special": false, - "text": "-" + "text": " has" }, { - "id": 9342, + "id": 539, "logprob": 0.0, "special": false, - "text": "comment" + "text": " not" }, { - "id": 29901, + "id": 3686, "logprob": 0.0, "special": false, - "text": ":" + "text": " yet" }, { - "id": 396, - "logprob": -0.27685547, - "special": false, - "text": " #" - }, - { - "id": 29906, - "logprob": -0.4970703, - "special": false, - "text": "2" - }, - { - "id": 29900, - "logprob": -0.80615234, - "special": false, - "text": "0" - }, - { - "id": 29896, + "id": 3288, "logprob": 0.0, "special": false, - "text": "1" + "text": " sent" }, { - "id": 29955, - "logprob": -1.0751953, + "id": 904, + "logprob": 0.0, "special": false, - "text": "7" + "text": " any" + }, + { + "id": 828, + "logprob": 0.0, + "special": false, + "text": " data" + }, + { + "id": 382, + "logprob": -1.5517578, + "special": false, + "text": ".\n\n" } ], "top_tokens": null }, - "generated_text": "Test request-for-comment: #2017" + "generated_text": "Test request. The server has not yet sent any data.\n\n" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json index 594b7351..75e90303 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -6,87 +6,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.828125, "text": "Test" }, { - "id": 2009, - "logprob": -9.609375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3300781, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8740234, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2646484, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.7158203, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4667969, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15344238, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.81591797, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22973633, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.007045746, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021957397, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -95,87 +90,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.59375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3378906, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8779297, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2636719, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6992188, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4589844, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15344238, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.79052734, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22937012, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.007041931, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.022140503, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -184,87 +174,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.609375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3261719, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8730469, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2587891, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6894531, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.46875, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.1541748, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.80322266, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22912598, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.0070495605, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021606445, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -273,86 +258,81 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.6015625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3320312, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.875, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2646484, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6884766, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4589844, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15185547, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.79833984, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22827148, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.006996155, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021560669, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" } ] diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json new file mode 100644 index 00000000..69c1f47d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.8359375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6171875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3417969, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8730469, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2626953, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7060547, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4482422, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15246582, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.796875, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22766113, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007045746, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021759033, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" +} diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json new file mode 100644 index 00000000..9b5ee9ee --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.7890625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 29899, + "logprob": -1.4980469, + "special": false, + "text": "-" + }, + { + "id": 1454, + "logprob": -0.19433594, + "special": false, + "text": "for" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 9342, + "logprob": 0.0, + "special": false, + "text": "comment" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 396, + "logprob": -0.27392578, + "special": false, + "text": " #" + }, + { + "id": 29906, + "logprob": -0.49389648, + "special": false, + "text": "2" + }, + { + "id": 29900, + "logprob": -0.81103516, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": 0.0, + "special": false, + "text": "1" + }, + { + "id": 29955, + "logprob": -1.0800781, + "special": false, + "text": "7" + } + ], + "top_tokens": null + }, + "generated_text": "Test request-for-comment: #2017" +} diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json new file mode 100644 index 00000000..df975635 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.8828125, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.5859375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3359375, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8623047, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2451172, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6923828, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4492188, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15197754, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.8022461, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22583008, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007095337, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021652222, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.796875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3476562, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8789062, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2734375, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.703125, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4677734, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15454102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.7973633, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.23278809, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.006980896, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.022033691, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.9296875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.5703125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3203125, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8486328, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2480469, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7060547, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4511719, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1529541, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.81396484, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22180176, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007133484, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021835327, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6171875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3261719, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8691406, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2597656, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7070312, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4550781, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1538086, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.79345703, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22924805, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0070266724, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021942139, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + } +] diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py index 135f4b05..94a48e49 100644 --- a/integration-tests/models/test_flash_llama_gptq.py +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def flash_llama_gptq_handle(launcher): - with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle: + with launcher( + "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq" + ) as handle: yield handle From 508e3080884b578c98795c44e0dea8a8b87979aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Jul 2024 16:07:48 +0200 Subject: [PATCH 112/340] Consistently take `prefix` in model constructors (#2191) * Consistently take `prefix` in model constructors * Release test check fix * Misc refactor-related fixes --- .../text_generation_server/models/__init__.py | 3 +- .../models/causal_lm.py | 3 +- .../models/custom_modeling/bloom_modeling.py | 2 +- .../models/custom_modeling/clip.py | 6 +-- .../custom_modeling/flash_cohere_modeling.py | 22 ++++++---- .../custom_modeling/flash_dbrx_modeling.py | 18 +++++--- .../custom_modeling/flash_gemma2_modeling.py | 10 ++--- .../custom_modeling/flash_gemma_modeling.py | 12 +++--- .../custom_modeling/flash_gpt2_modeling.py | 8 ++-- .../custom_modeling/flash_llama_modeling.py | 4 +- .../custom_modeling/flash_mistral_modeling.py | 8 ++-- .../custom_modeling/flash_mixtral_modeling.py | 10 ++--- .../custom_modeling/flash_neox_modeling.py | 16 +++++--- .../custom_modeling/flash_phi_modeling.py | 18 +++++--- .../custom_modeling/flash_qwen2_modeling.py | 20 +++++---- .../custom_modeling/flash_rw_modeling.py | 34 ++++++++------- .../flash_santacoder_modeling.py | 21 ++++++---- .../models/custom_modeling/mpt_modeling.py | 20 +++++---- .../models/custom_modeling/neox_modeling.py | 28 ++++++++----- .../models/custom_modeling/opt_modeling.py | 41 ++++++++++++------- .../models/custom_modeling/phi_modeling.py | 16 +++++--- .../models/flash_causal_lm.py | 19 +++++---- 22 files changed, 209 insertions(+), 130 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 15e74622..58131a3a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,7 @@ from text_generation_server.models.custom_modeling.opt_modeling import OPTForCau from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, ) +from text_generation_server.models.bloom import BloomCausalLMBatch from text_generation_server.models.custom_modeling.bloom_modeling import ( BloomForCausalLM, ) @@ -522,7 +523,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, - batch_class=CausalLMBatchKeysLast, + batch_class=BloomCausalLMBatch, ) elif model_type == MPT: return CausalLM( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cac36ebd..868a3cc0 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -553,7 +553,8 @@ class CausalLM(Model): if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) - model = model_class(config, weights) + prefix = "" + model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super().__init__( diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 0d8a1b59..77b89c5b 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel): class BloomForCausalLM(BloomPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.transformer = BloomModel(config, weights) diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 56618bf1..27b9ff1c 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module): class CLIPTextTransformer(nn.Module): - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix: str, config: CLIPTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size @@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel): _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix, config: CLIPTextConfig): super().__init__(config) - self.text_model = CLIPTextTransformer(config) + self.text_model = CLIPTextTransformer(prefix, config) # Initialize weights and apply final processing self.post_init() diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e088f9aa..f993fe72 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -363,9 +363,9 @@ class CohereMLP(nn.Module): class FlashCohereLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashCohereAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module): class FlashCohereModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashCohereLayer( + prefix, layer_id, config, weights, @@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="model.norm", weights=weights, eps=config.layer_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps ) self.gradient_checkpointing = False @@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module): class FlashCohereForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashCohereModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashCohereModel(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", weights=weights, ) self.logit_scale = config.logit_scale diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index aea7f399..e469495f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -593,9 +593,9 @@ class DenseMoE(nn.Module): class DbrxLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.blocks.{layer_id}" + prefix = f"{prefix}.blocks.{layer_id}" self.attn = DbrxNormAttentionNorm( prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights @@ -637,16 +637,17 @@ class DbrxLayer(nn.Module): class DbrxModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights + prefix=f"{prefix}.wte", weights=weights ) self.layers = nn.ModuleList( [ DbrxLayer( + prefix, layer_id, config, weights, @@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=1e-5 + prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5 ) self.head_size = self.layers[0].attn.self_attn.head_size @@ -702,9 +703,14 @@ class DbrxModel(torch.nn.Module): class FlashDbrxForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + self.model = DbrxModel(config, weights) self.lm_head = SpeculativeHead.load( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 625baa91..beff08b3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig): class Gemma2FastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module): class FlashGemma2Layer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): + def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool): super().__init__() self.self_attn = FlashGemma2Attention( prefix=f"{prefix}.self_attn", @@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module): class FlashGemma2Model(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, *, causal: bool = True): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index b7ce6307..14b62b00 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig): class GemmaFastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module): class GemmaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -299,7 +299,7 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal @@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module): class FlashGemmaModel(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, *, causal: bool = True): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 9f800146..d5dc25cf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module): class GPT2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.activation_function self.act = ( @@ -298,7 +298,7 @@ class GPT2MLP(nn.Module): class FlashGPT2Layer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.self_attn = FlashGPT2Attention( prefix=f"{prefix}.attn", config=config, weights=weights @@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module): class FlashGPT2Model(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module): class FlashGPT2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 77a7e2d5..78832341 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -54,7 +54,7 @@ if SYSTEM == "rocm": raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -def load_attention(config, prefix, weights, layer_id): +def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. bias = getattr(config, "attention_bias", False) head_size = config.hidden_size // config.num_attention_heads @@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 396969cd..8028dbe8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module): class MistralMLP(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -328,7 +328,7 @@ class MistralMLP(nn.Module): class MistralLayer(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", @@ -392,7 +392,7 @@ class MistralLayer(nn.Module): class MistralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, name=None): + def __init__(self, prefix: str, config, weights, name=None): if name is None: name = "model" super().__init__() diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2d6a7f97..429793ea 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights): ) -def _load_experts(config, prefix, mat, weights): +def _load_experts(config, prefix: str, mat, weights): if config.quantize is not None: raise NotImplementedError("Mixtral does not support weight quantization yet.") @@ -475,7 +475,7 @@ class DenseMoE(nn.Module): class MixtralLayer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() prefix = f"{prefix}.layers.{layer_id}" @@ -536,7 +536,7 @@ class MixtralLayer(nn.Module): class MixtralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( @@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module): class FlashMixtralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.model = MixtralModel(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 33aebc2b..0eca181b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( @@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ] ) self.final_layer_norm = FastLayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) - self.gpt_neox = FlashGPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index f237ea37..7401bc27 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -258,9 +258,9 @@ class PhiMLP(nn.Module): class FlashPhiLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashPhiAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module): class FlashPhiModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashPhiLayer( + prefix, layer_id, config, weights, @@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashPhiModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashPhiModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 1cc6a613..a98709c5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module): class Qwen2Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = Qwen2Attention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module): class Qwen2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ Qwen2Layer( + prefix, layer_id, config, weights, @@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module): ] ) self.norm = FastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) self.gradient_checkpointing = False @@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module): class Qwen2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = Qwen2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Qwen2Model(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index e7614232..d12ed567 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module): class FlashMLP(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.act = torch.nn.functional.gelu @@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module): def __init__( self, layer_id, + prefix: str, config, weights, ): @@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module): parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", @@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module): class FlashRWLayerNorm(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.num_ln = config.num_ln_in_parallel_attn @@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module): class FlashRWLargeLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_layer = FlashRWLayerNorm(config, prefix, weights) @@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.word_embeddings = TensorParallelEmbedding( - prefix="transformer.word_embeddings", weights=weights + prefix=f"{prefix}.word_embeddings", weights=weights ) if config.new_decoder_architecture: self.h = nn.ModuleList( [ - FlashRWLargeLayer(layer_id, config, weights) + FlashRWLargeLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel): else: self.h = nn.ModuleList( [ - FlashRWLayer(layer_id, config, weights) + FlashRWLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = self.h[0].self_attention.num_heads_kv self.ln_f = FastLayerNorm.load( - prefix="transformer.ln_f", + prefix=f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon, ) @@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.transformer = FlashRWModel(config, weights) + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.transformer = FlashRWModel(prefix, config, weights) self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index daef43cc..21a22046 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -346,9 +346,9 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_1 = FastLayerNorm.load( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon ) @@ -396,18 +396,18 @@ class Block(nn.Module): class FlashSantacoderModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config self.process_group = weights.process_group self.wte = TensorParallelEmbedding( - prefix="transformer.wte", + prefix=f"{prefix}.wte", weights=weights, reduce=False, ) self.wpe = TensorParallelEmbedding( - prefix="transformer.wpe", + prefix=f"{prefix}.wpe", weights=weights, reduce=False, ) @@ -415,6 +415,7 @@ class FlashSantacoderModel(nn.Module): self.layers = nn.ModuleList( [ Block( + prefix, layer_id, config, weights, @@ -466,10 +467,16 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + config.transpose = config.architectures[0].startswith("GPT2") - self.model = FlashSantacoderModel(config, weights) + self.model = FlashSantacoderModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index f7981bf5..fb09a8f1 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel): class MPTModel(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): # config._validate_config() super().__init__(config) self.world_size = weights.process_group.size() @@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel): f"Requested norm type ({config.norm_type}) is not implemented within this repo." ) - self.wte = TensorParallelEmbedding("transformer.wte", weights) + self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) if not self.alibi: - self.wpe = TensorParallelEmbedding("transformer.wpe", weights) + self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) self.blocks = nn.ModuleList( [ - MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) + MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) for i in range(config.n_layers) ] ) @@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel): class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(config, weights) + self.transformer = MPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) self.logit_scale = None if config.logit_scale is not None: diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index fcad32fa..8998778f 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module): class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", + prefix=f"{prefix}.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", + prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.attention = GPTNeoXAttention( - config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights ) self.mlp = GPTNeoXMLP( - config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights ) def forward( @@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module): class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.num_attention_heads = config.num_attention_heads self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( [ - GPTNeoXLayer(layer_id, config, weights) + GPTNeoXLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.final_layer_norm = nn.LayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.gpt_neox = GPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = GPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 9b2d01e0..5ab02959 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module): This module learns positional embeddings up to a fixed maximum size. """ - def __init__(self, weights): + def __init__(self, prefix: str, weights): super().__init__() self.offset = 2 self.weight = nn.Parameter( - weights.get_tensor("model.decoder.embed_positions.weight") + weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") ) def forward( @@ -311,11 +311,11 @@ class OPTAttention(nn.Module): class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, config: OPTConfig, weights): + def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"model.decoder.layers.{layer_id}" + prefix = f"{prefix}.decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel): class OPTDecoder(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.layerdrop @@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = TensorParallelEmbedding( - prefix="model.decoder.embed_tokens", weights=weights + prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) - self.embed_positions = OPTLearnedPositionalEmbedding(weights) + self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) if config.word_embed_proj_dim != config.hidden_size: self.project_out = FastLinear.load( - config, prefix="model.decoder.project_out", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_out", + weights=weights, + bias=False, ) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: self.project_in = FastLinear.load( - config, prefix="model.decoder.project_in", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_in", + weights=weights, + bias=False, ) else: self.project_in = None @@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel): # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm.load( - prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS + prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS ) else: self.final_layer_norm = None self.layers = nn.ModuleList( [ - OPTDecoderLayer(layer_id, config, weights) + OPTDecoderLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel): class OPTModel(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) - self.decoder = OPTDecoder(config, weights) + self.decoder = OPTDecoder(prefix, config, weights) # Initialize weights and apply final processing def forward( @@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + self.model = OPTModel(config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="model.decoder.embed_tokens", weights=weights + config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index 04b470eb..b4d56db1 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -248,16 +248,16 @@ class PhiBlock(nn.Module): # PhiModel implements the embedding layer and the transformer blocks. class PhiModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.tp_rank = weights.process_group.rank() self.tp_world_size = weights.process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.embd.wte", weights=weights + prefix=f"{prefix}.embd.wte", weights=weights ) self.blocks = nn.ModuleList( [ - PhiBlock(f"transformer.h.{layer_id}", config, weights) + PhiBlock(f"{prefix}.h.{layer_id}", config, weights) for layer_id in range(config.n_layer) ] ) @@ -289,9 +289,15 @@ class PhiModel(nn.Module): # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. class PhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = PhiModel(config, weights) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = PhiModel(prefix, config, weights) self.lm_head = PhiCausalLMHead(config, weights) def forward( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c7f5f1f9..e66011a1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -878,10 +878,6 @@ class FlashCausalLM(Model): ) config.quantize = quantize config.speculator = speculator - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None torch.distributed.barrier(group=self.process_group) @@ -900,13 +896,22 @@ class FlashCausalLM(Model): text_config = getattr(config, "text_config", None) if text_config is not None: config = text_config + + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - num_kv_heads = getattr(config, "num_key_value_heads", None) + # Order is important here. + for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: + num_kv_heads = getattr(config, "num_attention_heads", None) + if num_kv_heads is not None: + break if num_kv_heads is None: - # Final overide for GPT2 - num_kv_heads = config.n_head + raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = num_kv_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads From 8e3d1e6c3f3791b883f4f9a154adf3a717d4a1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?icyboy=E2=84=A2?= Date: Mon, 8 Jul 2024 15:01:14 +0800 Subject: [PATCH 113/340] fix dbrx & opt model prefix bug (#2201) * Update idefics_causal_lm.py Fix syntax issues * fix dbrx & opt model prefix bug --- .../models/custom_modeling/flash_dbrx_modeling.py | 2 +- .../models/custom_modeling/opt_modeling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index e469495f..41aa5859 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -711,7 +711,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): else: prefix = f"{prefix}.transformer" - self.model = DbrxModel(config, weights) + self.model = DbrxModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 5ab02959..84a1c069 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -757,7 +757,7 @@ class OPTForCausalLM(OPTPreTrainedModel): else: prefix = f"{prefix}.model" - self.model = OPTModel(config, weights) + self.model = OPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights From f11fd699b656fb161d9fe42d45b40fd7d9f1cd90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 09:52:12 +0200 Subject: [PATCH 114/340] hotfix: Fix number of KV heads (#2202) Fix number of KV heads --- server/text_generation_server/models/flash_causal_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e66011a1..42b2f686 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -906,8 +906,8 @@ class FlashCausalLM(Model): # Validation is done in the model itself if num_kv_heads is None: # Order is important here. - for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: - num_kv_heads = getattr(config, "num_attention_heads", None) + for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: + num_kv_heads = getattr(config, attr, None) if num_kv_heads is not None: break if num_kv_heads is None: From 17594916edfa71418de2adecfff6ca667b78e70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 11:19:48 +0200 Subject: [PATCH 115/340] Fix incorrect cache allocation with multi-query (#2203) We wouldn't allocate any memory in multi-query (1 KV head). Fixes Starcoder et al. --- server/text_generation_server/models/flash_causal_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 42b2f686..5c086a73 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -912,7 +912,12 @@ class FlashCausalLM(Model): break if num_kv_heads is None: raise ValueError("Cannot get the number of key/value heads") - self.num_kv_heads = num_kv_heads // self.process_group.size() + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} From 540e710c3f837745b48bf25b2854130c5a526554 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 13:22:38 +0200 Subject: [PATCH 116/340] Falcon/DBRX: get correct number of key-value heads (#2205) --- server/text_generation_server/models/__init__.py | 4 ++++ .../models/custom_modeling/flash_dbrx_modeling.py | 12 ++++++++++++ .../models/custom_modeling/flash_rw_modeling.py | 1 + .../text_generation_server/models/flash_causal_lm.py | 11 +++++------ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 58131a3a..ba980195 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -797,6 +797,10 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 41aa5859..44411687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig): + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + def __init__( self, d_model: int = 2048, @@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig): **kwargs, ) + @property + def num_key_value_heads(self): + # We can't use the attribute map, since this the number of KV + # heads is not top-level. + return self.attn_config.kv_n_heads + def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d12ed567..4813e2df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig): attribute_map = { "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", + "num_key_value_heads": "n_head_kv", } def __init__( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5c086a73..bf1fda4a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -905,13 +905,12 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - # Order is important here. - for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: - num_kv_heads = getattr(config, attr, None) - if num_kv_heads is not None: - break + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround if num_kv_heads is None: - raise ValueError("Cannot get the number of key/value heads") + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = ( num_kv_heads // self.process_group.size() if num_kv_heads > 1 From 8dd9b2b13537d0d607941f120ead1db5b30156fc Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 8 Jul 2024 21:57:06 +0800 Subject: [PATCH 117/340] add doc for intel gpus (#2181) Signed-off-by: Wang, Yi A --- docs/source/_toctree.yml | 2 ++ docs/source/architecture.md | 1 + docs/source/installation_intel.md | 19 +++++++++++++++++++ docs/source/quicktour.md | 2 +- 4 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 docs/source/installation_intel.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index c9b4efd9..119c5662 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -11,6 +11,8 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_intel + title: Using TGI with Intel GPUs - local: installation title: Installation from source - local: supported_models diff --git a/docs/source/architecture.md b/docs/source/architecture.md index a8418817..28c84f62 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ. - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi). - A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md new file mode 100644 index 00000000..f9fda863 --- /dev/null +++ b/docs/source/installation_intel.md @@ -0,0 +1,19 @@ +# Using TGI with Intel GPUs + +TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html), [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html), the recommended usage is through Docker. + + +On a server powered by Intel GPUs, TGI can be launched with the following command: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --rm --privileged --cap-add=sys_nice \ + --device=/dev/dri \ + --ipc=host --shm-size 1g --net host -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:latest-intel \ + --model-id $model --cuda-graphs 0 +``` + +The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index c546bc03..f056baad 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ ### Supported hardware -TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. +TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. ## Consuming TGI From 4a54e41920499341eba2e12b8a4a5e5bfa18671e Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Mon, 8 Jul 2024 15:59:16 +0200 Subject: [PATCH 118/340] fix: python deserialization (#2178) --- clients/python/text_generation/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index a56edaca..e36dd470 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): role: str content: Optional[str] = None - tool_calls: Optional[ChoiceDeltaToolCall] + tool_calls: Optional[ChoiceDeltaToolCall] = None class Choice(BaseModel): From 74edda9c235bfd0d076e85a4816003e6149b1085 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 8 Jul 2024 22:03:59 +0800 Subject: [PATCH 119/340] =?UTF-8?q?update=20to=20metrics=200.23.0=20or=20c?= =?UTF-8?q?ould=20work=20with=20metrics-exporter-promethe=E2=80=A6=20(#219?= =?UTF-8?q?0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit update to metrics 0.23.0 or could work with metrics-exporter-prometheus 0.15.1 Signed-off-by: Wang, Yi A --- Cargo.lock | 28 ++----------- router/Cargo.toml | 2 +- router/src/infer/mod.rs | 8 ++-- router/src/infer/v2/queue.rs | 8 ++-- router/src/infer/v2/scheduler.rs | 61 ++++++++++++++++------------ router/src/infer/v3/queue.rs | 8 ++-- router/src/infer/v3/scheduler.rs | 61 ++++++++++++++++------------ router/src/server.rs | 68 ++++++++++++++------------------ router/src/validation.rs | 4 +- 9 files changed, 119 insertions(+), 129 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 50799b9d..de3977c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2011,17 +2011,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "metrics" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" -dependencies = [ - "ahash", - "metrics-macros", - "portable-atomic", -] - [[package]] name = "metrics" version = "0.23.0" @@ -2045,7 +2034,7 @@ dependencies = [ "hyper-util", "indexmap 2.2.6", "ipnet", - "metrics 0.23.0", + "metrics", "metrics-util", "quanta", "thiserror", @@ -2053,17 +2042,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "metrics-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", -] - [[package]] name = "metrics-util" version = "0.17.0" @@ -2073,7 +2051,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "metrics 0.23.0", + "metrics", "num_cpus", "quanta", "sketches-ddsketch", @@ -3921,7 +3899,7 @@ dependencies = [ "init-tracing-opentelemetry", "itertools 0.10.5", "jsonschema", - "metrics 0.21.1", + "metrics", "metrics-exporter-prometheus", "minijinja", "minijinja-contrib", diff --git a/router/Cargo.toml b/router/Cargo.toml index 5855ac86..60fb5c9d 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -24,7 +24,7 @@ futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.21.1" +metrics = "0.23.0" metrics-exporter-prometheus = { version = "0.15.1", features = [] } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 49282eb9..f3b10450 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -91,14 +91,14 @@ impl Infer { .limit_concurrent_requests .try_acquire_owned() .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); + metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1); tracing::error!("{err}"); err })?; // Validate request let valid_request = self.validation.validate(request).await.map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); err })?; @@ -140,7 +140,7 @@ impl Infer { .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .apply(messages, grammar_with_prompt) .map_err(|e| { - metrics::increment_counter!("tgi_request_failure", "err" => "template"); + metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); e }) @@ -214,7 +214,7 @@ impl Infer { }) } else { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); Err(err) } diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 93cf9469..0b51645a 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -111,7 +111,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -124,7 +124,7 @@ async fn queue_task( let next_batch = state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); }), } } @@ -226,7 +226,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -336,7 +336,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index e4c3de26..97379bc5 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -148,8 +148,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -170,9 +170,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -219,8 +221,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -234,7 +236,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -248,11 +250,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -261,7 +267,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -276,7 +282,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -291,13 +297,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -307,7 +318,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -381,7 +392,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -407,7 +418,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index ba65b9b6..894d9cab 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -126,7 +126,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -141,7 +141,7 @@ async fn queue_task( .instrument(span) .await; response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); } } } @@ -248,7 +248,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -399,7 +399,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 543ce89f..26cd9584 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -154,8 +154,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -176,9 +176,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -225,8 +227,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -240,7 +242,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -254,11 +256,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -267,7 +273,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -282,7 +288,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -297,13 +303,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -313,7 +324,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -387,7 +398,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -413,7 +424,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. diff --git a/router/src/server.rs b/router/src/server.rs index db8b16ad..4af8962e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -185,7 +185,7 @@ pub(crate) async fn generate_internal( span: tracing::Span, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // Do not long ultra long inputs, like image payloads. tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]); @@ -301,25 +301,15 @@ pub(crate) async fn generate_internal( ); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_validation_duration", - validation_time.as_secs_f64() - ); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_inference_duration", - inference_time.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_mean_time_per_token_duration", - time_per_token.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_generated_tokens", - response.generated_text.generated_tokens as f64 - ); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration") + .record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens") + .record(response.generated_text.generated_tokens as f64); // Send response let mut output_text = response.generated_text.text; @@ -399,7 +389,7 @@ async fn generate_stream_internal( span: tracing::Span, ) -> (HeaderMap, impl Stream>) { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); tracing::debug!("Input: {}", req.inputs); @@ -427,12 +417,12 @@ async fn generate_stream_internal( let best_of = req.parameters.best_of.unwrap_or(1); if best_of != 1 { let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else if req.parameters.decoder_input_details { let err = InferError::from(ValidationError::PrefillDetailsStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else { @@ -500,13 +490,13 @@ async fn generate_stream_internal( span.record("seed", format!("{:?}", generated_text.seed)); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); - metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64); // StreamResponse end_reached = true; @@ -553,7 +543,7 @@ async fn generate_stream_internal( // Skip if we already sent an error if !end_reached && !error { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } @@ -604,7 +594,7 @@ async fn completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { max_tokens, @@ -625,7 +615,7 @@ async fn completions( // if suffix is present throw an error if req.suffix.is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -637,7 +627,7 @@ async fn completions( } if req.prompt.0.len() > info.max_client_batch_size { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -1009,7 +999,7 @@ async fn chat_completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let ChatRequest { logprobs, max_tokens, @@ -1039,7 +1029,7 @@ async fn chat_completions( // response_format and tools are mutually exclusive if response_format.is_some() && tools.as_ref().is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -1053,7 +1043,7 @@ async fn chat_completions( let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { Ok(grammar) => grammar, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1082,7 +1072,7 @@ async fn chat_completions( let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { Ok(inputs) => inputs, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1280,7 +1270,7 @@ async fn vertex_compatibility( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance if req.instances.is_empty() { diff --git a/router/src/validation.rs b/router/src/validation.rs index 12cf2ab3..07ad14c9 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -157,7 +157,7 @@ impl Validation { )); } - metrics::histogram!("tgi_request_input_length", input_length as f64); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, input_length, max_new_tokens)) } // Return inputs without validation @@ -384,7 +384,7 @@ impl Validation { ignore_eos_token: false, }; - metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); + metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64); Ok(ValidGenerateRequest { inputs, From 48f1196da88a62fba53028c0402a08aa9e246980 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 8 Jul 2024 10:06:49 -0400 Subject: [PATCH 120/340] feat: use model name as adapter id in chat endpoints (#2128) --- router/src/lib.rs | 4 ++-- router/src/server.rs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 165b2ad2..080c029a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -384,7 +384,7 @@ pub struct CompletionRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] @@ -731,7 +731,7 @@ impl ChatCompletionChunk { pub(crate) struct ChatRequest { #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// A list of messages comprising the conversation so far. #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] diff --git a/router/src/server.rs b/router/src/server.rs index 4af8962e..4b52710d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -597,6 +597,7 @@ async fn completions( metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { + model, max_tokens, seed, stop, @@ -665,7 +666,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - ..Default::default() + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), }, }) .collect(); @@ -1001,6 +1002,7 @@ async fn chat_completions( let span = tracing::Span::current(); metrics::counter!("tgi_request_count").increment(1); let ChatRequest { + model, logprobs, max_tokens, messages, @@ -1106,7 +1108,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, - ..Default::default() + adapter_id: model.filter(|m| *m != "tgi").map(String::from), }, }; From eaaea91e2bfdcfbc9d6a2c13d623253181a83143 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:52:10 +0200 Subject: [PATCH 121/340] Fix nccl regression on PyTorch 2.3 upgrade (#2099) * fix nccl issue * add note in dockerfile * use v2.22.3 that also fixes @samsamoa's repro * poetry actually can't handle the conflict between torch and nccl * set LD_PRELOAD --- server/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile b/server/Makefile index 0099c56a..d701c819 100644 --- a/server/Makefile +++ b/server/Makefile @@ -35,5 +35,5 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes + poetry export -o requirements_cuda.txt --without-hashes --with cuda poetry export -o requirements_rocm.txt --without-hashes From 591f9f70ebe10ee0f111285cdfcf78c84e537b9b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Jul 2024 11:13:48 +0200 Subject: [PATCH 122/340] Adding sanity check to openapi docs. --- docs/openapi.json | 64 +++++++++++++++++++++++-- docs/source/basic_tutorials/launcher.md | 19 +------- router/src/server.rs | 6 ++- update_doc.py | 20 ++++++-- 4 files changed, 82 insertions(+), 27 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 03e2d4ff..2a331c80 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -809,7 +809,6 @@ "ChatRequest": { "type": "object", "required": [ - "model", "messages" ], "properties": { @@ -854,7 +853,8 @@ "model": { "type": "string", "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "n": { "type": "integer", @@ -1116,7 +1116,6 @@ "CompletionRequest": { "type": "object", "required": [ - "model", "prompt" ], "properties": { @@ -1138,7 +1137,8 @@ "model": { "type": "string", "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "prompt": { "$ref": "#/components/schemas/Prompt" @@ -1708,6 +1708,62 @@ } } }, + "MessageChunk": { + "oneOf": [ + { + "type": "object", + "required": [ + "text", + "type" + ], + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "enum": [ + "text" + ] + } + } + }, + { + "type": "object", + "required": [ + "image_url", + "type" + ], + "properties": { + "image_url": { + "$ref": "#/components/schemas/Url" + }, + "type": { + "type": "string", + "enum": [ + "image_url" + ] + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, + "MessageContent": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageChunk" + } + } + ] + }, "PrefillToken": { "type": "object", "required": [ diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f..1e5b6fd2 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -62,9 +62,7 @@ Options: Possible values: - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from - - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - - marlin: 4 bit quantization. Requires a specific Marlin quantized model: - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model @@ -126,7 +124,7 @@ Options: ## MAX_TOP_N_TOKENS ```shell --max-top-n-tokens - This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking + This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking [env: MAX_TOP_N_TOKENS=] [default: 5] @@ -336,13 +334,6 @@ Options: --otlp-endpoint [env: OTLP_ENDPOINT=] -``` -## OTLP_SERVICE_NAME -```shell - --otlp-service-name - [env: OTLP_SERVICE_NAME=] - [default: text-generation-inference.router] - ``` ## CORS_ALLOW_ORIGIN ```shell @@ -416,14 +407,6 @@ Options: [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] -``` -## LORA_ADAPTERS -```shell - --lora-adapters - Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request - - [env: LORA_ADAPTERS=] - ``` ## HELP ```shell diff --git a/router/src/server.rs b/router/src/server.rs index 4b52710d..8cc09af3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,8 +13,8 @@ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, - Usage, Validation, + Message, MessageChunk, MessageContent, PrefillToken, SimpleToken, StreamDetails, + StreamResponse, Token, TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1446,6 +1446,8 @@ pub async fn run( GrammarType, ChatRequest, Message, + MessageContent, + MessageChunk, ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, diff --git a/update_doc.py b/update_doc.py index 1ff94a2c..03b5c792 100644 --- a/update_doc.py +++ b/update_doc.py @@ -155,7 +155,7 @@ def check_openapi(check: bool): filename, ], capture_output=True, - ).stdout.decode() + ).stdout.decode("utf-8") os.remove(tmp_filename) if diff: @@ -164,11 +164,25 @@ def check_openapi(check: bool): "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" ) - return True else: os.rename(tmp_filename, filename) print("OpenAPI documentation updated.") - return True + errors = subprocess.run( + [ + "swagger-cli", + # allow for trailing whitespace since it's not significant + # and the precommit hook will remove it + "validate", + filename, + ], + capture_output=True, + ).stderr.decode("utf-8") + if errors: + print(errors) + raise Exception( + f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + ) + return True def main(): From cc4fceb21d3a8856efa7701a82f827198c1eb08e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Jul 2024 17:23:48 +0200 Subject: [PATCH 123/340] Updating the self check (#2209) * Updating the self check * Fix. * Revert the CLI . * cli. * Space. * Revert cargo update. --- docs/openapi.json | 88 ++++++++++++++++++++++++- docs/source/basic_tutorials/launcher.md | 19 +++++- router/src/lib.rs | 2 +- router/src/server.rs | 19 ++++-- update_doc.py | 4 +- 5 files changed, 121 insertions(+), 11 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 2a331c80..113fd8b5 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -492,12 +492,12 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/Completion" + "$ref": "#/components/schemas/CompletionFinal" } }, "text/event-stream": { "schema": { - "$ref": "#/components/schemas/CompletionCompleteChunk" + "$ref": "#/components/schemas/Chunk" } } } @@ -1324,6 +1324,17 @@ } } }, + "FunctionName": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string" + } + } + }, "GenerateParameters": { "type": "object", "properties": { @@ -1764,6 +1775,16 @@ } ] }, + "OutputMessage": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + } + ] + }, "PrefillToken": { "type": "object", "required": [ @@ -1890,6 +1911,23 @@ } } }, + "TextMessage": { + "type": "object", + "required": [ + "role", + "content" + ], + "properties": { + "content": { + "type": "string", + "example": "My name is David and I" + }, + "role": { + "type": "string", + "example": "user" + } + } + }, "Token": { "type": "object", "required": [ @@ -1962,6 +2000,41 @@ } } }, + "ToolCallDelta": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "$ref": "#/components/schemas/DeltaToolCall" + } + } + }, + "ToolCallMessage": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + } + } + } + }, "ToolType": { "oneOf": [ { @@ -1985,6 +2058,17 @@ } ] }, + "Url": { + "type": "object", + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + } + } + }, "Usage": { "type": "object", "required": [ diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 1e5b6fd2..5e40146f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -62,7 +62,9 @@ Options: Possible values: - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from + - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels + - marlin: 4 bit quantization. Requires a specific Marlin quantized model: - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model @@ -124,7 +126,7 @@ Options: ## MAX_TOP_N_TOKENS ```shell --max-top-n-tokens - This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking + This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking [env: MAX_TOP_N_TOKENS=] [default: 5] @@ -334,6 +336,13 @@ Options: --otlp-endpoint [env: OTLP_ENDPOINT=] +``` +## OTLP_SERVICE_NAME +```shell + --otlp-service-name + [env: OTLP_SERVICE_NAME=] + [default: text-generation-inference.router] + ``` ## CORS_ALLOW_ORIGIN ```shell @@ -407,6 +416,14 @@ Options: [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] +``` +## LORA_ADAPTERS +```shell + --lora-adapters + Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request + + [env: LORA_ADAPTERS=] + ``` ## HELP ```shell diff --git a/router/src/lib.rs b/router/src/lib.rs index 080c029a..f856406d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -848,7 +848,7 @@ pub enum ToolType { Function { function: FunctionName }, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] pub struct FunctionName { pub name: String, } diff --git a/router/src/server.rs b/router/src/server.rs index 8cc09af3..4e5af99c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -11,10 +11,11 @@ use crate::kserve::{ }; use crate::validation::ValidationError; use crate::{ - BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, MessageChunk, MessageContent, PrefillToken, SimpleToken, StreamDetails, - StreamResponse, Token, TokenizeResponse, Usage, Validation, + BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, + GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, + HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, + SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse, + ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -562,8 +563,8 @@ request_body = CompletionRequest, responses( (status = 200, description = "Generated Chat Completion", content( -("application/json" = Completion), -("text/event-stream" = CompletionCompleteChunk), +("application/json" = CompletionFinal), +("text/event-stream" = Chunk), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), @@ -1448,6 +1449,12 @@ pub async fn run( Message, MessageContent, MessageChunk, + Url, + FunctionName, + OutputMessage, + TextMessage, + ToolCallMessage, + ToolCallDelta, ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, diff --git a/update_doc.py b/update_doc.py index 03b5c792..bfa7e4e9 100644 --- a/update_doc.py +++ b/update_doc.py @@ -177,7 +177,9 @@ def check_openapi(check: bool): ], capture_output=True, ).stderr.decode("utf-8") - if errors: + # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where + # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 + if not errors.startswith("Swagger schema validation failed."): print(errors) raise Exception( f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" From 2a6c3caf1d913cb708b0fa1649a8dcfd45e742ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 9 Jul 2024 20:04:03 +0200 Subject: [PATCH 124/340] Move quantized weight handling out of the `Weights` class (#2194) Quantized weights were loaded in the `Weights` class, but this was getting quite unwieldy, where every higher level method to load weights was a long conditional to cover all the different quantizers. This change moves loading of quantized weights out of the `Weights` class. This is done by defining a simple `WeightsLoader` interface that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader`. These implementations are in the quantizers' respective modules. The `Weights` class provides the low-level load operations (such as loading tensors or sharded tensors), but delegates loads that need quantizer-specific weight processing to a loader. The loaders still use the low-level functionality provided by `Weights`. I initially tried making a hierarchy where a class like `GPTQWeights` would inherit from `Weights`. But it is not very flexible (e.g. does not work well with the new weight storage mock used in tests) and the implicit indirections made the code harder to follow. --- server/tests/utils/test_layers.py | 8 +- server/tests/utils/test_weights.py | 162 ++-- server/text_generation_server/layers/exl2.py | 60 ++ .../layers/gptq/__init__.py | 354 ++++++++- .../layers/gptq/quantize.py | 3 + .../text_generation_server/layers/marlin.py | 130 +++- .../layers/tensor_parallel.py | 17 +- .../models/causal_lm.py | 12 +- .../custom_modeling/flash_cohere_modeling.py | 1 - .../custom_modeling/flash_gemma2_modeling.py | 1 - .../custom_modeling/flash_gemma_modeling.py | 1 - .../custom_modeling/flash_gpt2_modeling.py | 7 +- .../custom_modeling/flash_mixtral_modeling.py | 1 - .../custom_modeling/flash_neox_modeling.py | 4 +- .../custom_modeling/flash_phi_modeling.py | 1 - .../custom_modeling/flash_rw_modeling.py | 2 +- .../flash_santacoder_modeling.py | 19 +- .../flash_starcoder2_modeling.py | 1 - .../models/flash_causal_lm.py | 11 +- .../text_generation_server/models/idefics.py | 5 + server/text_generation_server/models/mamba.py | 12 +- .../models/seq2seq_lm.py | 5 + .../utils/quantization.py | 119 +++ .../text_generation_server/utils/weights.py | 691 +++--------------- 24 files changed, 896 insertions(+), 731 deletions(-) create mode 100644 server/text_generation_server/utils/quantization.py diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 9a8da0d6..1e3aaf6b 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -2,6 +2,7 @@ import torch from text_generation_server.layers import ( TensorParallelEmbedding, ) +from text_generation_server.utils.weights import DefaultWeightsLoader class ProcessGroup: @@ -42,7 +43,12 @@ class Weights: def test_weight_hub_files_offline_error(): vocab_size = 17 - weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) + weights = Weights( + rank=0, + world_size=1, + vocab_size=vocab_size, + hidden_dim=256, + ) embeddings = TensorParallelEmbedding("", weights) input_ids = torch.arange(vocab_size) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 8f88b1f8..36b27be8 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -1,13 +1,47 @@ import pytest import torch -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.marlin import MarlinWeight +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + Weights, + WeightsLoader, +) +from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader +from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader +from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path + +@pytest.fixture +def gptq_weights_loader(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="gptq", + quantize="gptq", + sym=True, + ) + + +@pytest.fixture +def gptq_weights_loader_awq(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="awq", + quantize="awq", + sym=True, + ) + + +@pytest.fixture +def marlin_weights_loader(): + return MarlinWeightsLoader(bits=4, is_marlin_24=False) + + dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -58,7 +92,7 @@ dummy_file_system = { dtype=torch.float32, ), }, - "test_get_multi_weights_row": { + "test_get_weights_row": { "weight.weight": torch.tensor( [ [1, 2], @@ -101,7 +135,7 @@ dummy_file_system = { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), }, - "test_get_multi_weights_row_gptq": { + "test_get_weights_row_gptq": { "weight.qweight": torch.tensor( [ [1, 2], @@ -200,7 +234,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_exl2": { + "test_get_weights_row_exl2": { "weight.q_weight": torch.tensor( [ [1, 2], @@ -245,7 +279,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_marlin": { + "test_get_weights_row_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, @@ -308,6 +342,7 @@ class MockWeights(Weights): dummy_fs, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, ): routing = {} self.dummy_fs = dummy_fs @@ -327,6 +362,9 @@ class MockWeights(Weights): self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = ( + DefaultWeightsLoader() if weights_loader is None else weights_loader + ) self._handles = {} def _get_handle(self, filename: Union[Path, str]): @@ -412,12 +450,10 @@ def test_get_weights_col_packed(): ) prefix = "weight" - quantize = None block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size(): ) prefix = "weight" - quantize = None block_sizes = 2 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr(): ) prefix = "weight" - quantize = None block_sizes = [1, 1] w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -519,11 +551,9 @@ def test_get_multi_weights_col(): ) prefixes = ["weight", "weight"] - quantize = None w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -545,10 +575,10 @@ def test_get_multi_weights_col(): ) -def test_get_multi_weights_row(): +def test_get_weights_row(): weights = MockWeights( [ - "test_get_multi_weights_row", + "test_get_weights_row", ], device="cpu", dtype=torch.float32, @@ -557,11 +587,9 @@ def test_get_multi_weights_row(): ) prefix = "weight" - quantize = None - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) assert torch.allclose( @@ -576,7 +604,7 @@ def test_get_multi_weights_row(): # test_get_weights_col -def test_get_weights_col_awq(): +def test_get_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -585,14 +613,13 @@ def test_get_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -617,7 +644,7 @@ def test_get_weights_col_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_gtpq(): +def test_get_weights_col_gtpq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -626,14 +653,13 @@ def test_get_weights_col_gtpq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -664,14 +690,13 @@ def test_get_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) scaled_scale_max = 0.3906 * 256 @@ -692,7 +717,7 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(): +def test_get_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_marlin", @@ -701,14 +726,13 @@ def test_get_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( @@ -723,7 +747,7 @@ def test_get_weights_col_marlin(): # test_get_weights_col_packed -def test_get_weights_col_packed_awq(): +def test_get_weights_col_packed_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_packed_gptq(): +def test_get_weights_col_packed_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(): +def test_get_weights_col_packed_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_marlin", @@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin(): # test_get_multi_weights_col -def test_get_multi_weights_col_awq(): +def test_get_multi_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefixes = ["weight"] - quantize = "awq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" try: w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) except ValueError as e: assert e.args[0] == "get_multi_weights_col is not supported for exl2" -def test_get_multi_weights_col_gptq(): +def test_get_multi_weights_col_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(): +def test_get_multi_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_marlin", @@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin(): assert torch.allclose(w.s, expected_weight.s), "s mismatch" -# test_get_multi_weights_row +# test_get_weights_row -def test_get_multi_weights_row_awq(): +def test_get_weights_row_awq(gptq_weights_loader_awq): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_exl2(): +def test_get_weights_row_exl2(): weights = MockWeights( [ - "test_get_multi_weights_row_exl2", + "test_get_weights_row_exl2", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) print(w) @@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_multi_weights_row_gptq(): +def test_get_weights_row_gptq(gptq_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_marlin(): +def test_get_weights_row_marlin(marlin_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_marlin", + "test_get_weights_row_marlin", ], device="cpu", dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index f6cb729e..55cba1cc 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,6 +1,9 @@ import torch +from typing import List, Union from dataclasses import dataclass +from text_generation_server.utils.weights import WeightsLoader, Weights + @dataclass class Exl2Weight: @@ -21,3 +24,60 @@ class Exl2Weight: @property def device(self) -> torch.device: return self.q_weight.device + + +class Exl2WeightsLoader(WeightsLoader): + """Loader for exl2-quantized weights.""" + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl") + + def get_weights_col(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + raise ValueError("get_multi_weights_col is not supported for exl2") + + def get_weights_row(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 56080145..efcb3118 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,20 +1,14 @@ from dataclasses import dataclass +from loguru import logger import os -from typing import Optional +from typing import List, Optional, Union +from safetensors import SafetensorError +from text_generation_server.utils.weights import Weights, WeightsLoader import torch from text_generation_server.utils.import_utils import ( SYSTEM, ) - - -@dataclass -class GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool +from text_generation_server.utils.log import log_once @dataclass @@ -69,3 +63,341 @@ elif CAN_EXLLAMA: pass from text_generation_server.layers.gptq.quant_linear import QuantLinear + + +class GPTQWeightsLoader(WeightsLoader): + """ + Loader for GPTQ- and AWQ-quantized weights. + """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + use_exllama = ( + self.bits == 4 + and HAS_EXLLAMA + and self.quantize == "gptq" + and not self.desc_act + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + + if weights.process_group.size() > 1: + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // self.groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + if use_exllama and self.groupsize != -1: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def _get_gptq_params(self, weights: Weights): + try: + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + self.sym = False + self.quant_method = "gptq" + except (SafetensorError, RuntimeError) as e: + pass diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 8d029817..c65d5e78 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional +from text_generation_server.utils.weights import DefaultWeightsLoader + DEV = torch.device("cuda:0") @@ -891,6 +893,7 @@ def quantize( dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, + weights_loader=DefaultWeightsLoader(), ) hooks = [] for name, module in model.named_modules(): diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a1af67a3..ecb88e76 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union +from text_generation_server.utils.weights import Weights, WeightsLoader import torch import torch.nn as nn -from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.import_utils import SYSTEM try: @@ -24,16 +24,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_TILE_SIZE = 16 -def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: +class MarlinWeightsLoader(WeightsLoader): + """Loader for Marlin-quantized weights.""" + + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_row(self, weights: Weights, prefix: str): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: return ( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 and quantize == "gptq" - and gptq_params.quant_method == "gptq" - and gptq_params.bits in GPTQ_MARLIN_BITS - and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES - and gptq_params.sym + and quant_method == "gptq" + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + and sym ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 038de258..011f105b 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer): weight = weights.get_tensor(f"{prefix}.weight") except: # ...otherwise they are quantized. - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 elif weights.process_group.size() > 1: try: @@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_gate_up(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_gate_up( - prefix, quantize=config.quantize - ) + weight = weights.get_weights_col_packed_gate_up(prefix) if bias: raise NotImplementedError("packed_gate_up only implemented without bias") else: @@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer): """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv( prefix, - quantize=config.quantize, num_heads=num_heads, num_key_value_heads=num_key_value_heads, ) @@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: @@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer): if config.quantize == "exl2": linears = [] for prefix in prefixes: - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None linears.append(get_linear(weight, b, config.quantize)) linear = LayerConcat(linears) else: - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) + weight = weights.get_multi_weights_col(prefixes, dim=dim) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = torch.cat(b, dim=dim) @@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 868a3cc0..0ea82b1e 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -20,6 +20,7 @@ from text_generation_server.utils import ( from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -546,12 +547,17 @@ class CausalLM(Model): tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index f993fe72..25719b99 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index beff08b3..a3ce5521 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 14b62b00..34a7efa2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d5dc25cf..cbfcb1b8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights): # Weights weight = weights.get_weights_col_packed_qkv( f"{prefix}.c_attn", - config.quantize, config.num_attention_heads, config.num_attention_heads, ) @@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool): """load_row, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) else: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T @@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool): """load_col, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=1 - ) + weight = weights.get_multi_weights_col([prefix], dim=1) else: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 429793ea..49c0e903 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 0eca181b..85dcb2a6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): - weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) + weight = weights.get_multi_weights_col([prefix], dim=0) if isinstance(weight, torch.Tensor): # Only on non quantized versions weight = ( diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 7401bc27..6c508264 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 4813e2df..65b40fed 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -23,7 +23,7 @@ from text_generation_server.layers.attention import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 21a22046..77b9d49c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -81,11 +82,13 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - gptq_params = weights._get_gptq_params() - if gptq_params.quant_method == "gptq": + loader = weights.weights_loader + assert isinstance(loader, GPTQWeightsLoader) + loader._get_gptq_params(weights) + if loader.quant_method == "gptq": g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - elif gptq_params.quant_method == "awq": + elif loader.quant_method == "awq": g_idx = None from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, @@ -100,8 +103,8 @@ def _load_multi_mqa_gptq( qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, + bits=loader.bits, + groupsize=loader.groupsize, use_exllama=HAS_EXLLAMA, ) @@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T else: - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=0 - ) + weight = weights.get_multi_weights_col([prefix], dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) @@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T else: - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 2b346283..19556f78 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bf1fda4a..2ca9eef3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -50,6 +50,7 @@ from text_generation_server.models.globals import ( from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( @@ -881,12 +882,16 @@ class FlashCausalLM(Model): torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader(quantize, model_id, revision) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device, dtype, process_group=self.process_group, aliases=aliases + filenames, + device, + dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index f2955bd0..0deab6ce 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -23,6 +23,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.quantization import get_loader class IDEFICSSharded(IdeficsCausalLM): @@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM): device=device, dtype=dtype, process_group=self.process_group, + weights_loader=weights_loader, ) model = IdeficsForVisionText2Text(config, weights) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9189b45c..4ed9722c 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -28,6 +28,7 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -448,8 +449,17 @@ class Mamba(Model): config.quantize = quantize config.speculator = speculator torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) model = MambaModel(config, weights) torch.distributed.barrier(group=self.process_group) super(Mamba, self).__init__( diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index dbaf1253..fa8b5025 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -18,6 +18,7 @@ from text_generation_server.utils import ( Weights, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -586,6 +587,9 @@ class Seq2SeqLM(Model): ) tokenizer.bos_token_id = config.decoder_start_token_id + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -594,6 +598,7 @@ class Seq2SeqLM(Model): dtype=dtype, process_group=self.process_group, aliases=aliases, + weights_loader=weights_loader, ) if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py new file mode 100644 index 00000000..07975bea --- /dev/null +++ b/server/text_generation_server/utils/quantization.py @@ -0,0 +1,119 @@ +from typing import Optional +import os +import json +from dataclasses import dataclass + +from huggingface_hub import hf_hub_download + +from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader + + +@dataclass +class _QuantizerConfig: + bits: int + checkpoint_format: Optional[str] + desc_act: bool + groupsize: int + quant_method: str + sym: bool + + +# We should probably do this with Pytantic JSON deserialization, +# but for now we'll stay close to the old _set_gptq_params. +def _get_quantizer_config(model_id, revision): + bits = 4 + groupsize = -1 + quant_method = "gptq" + checkpoint_format = None + sym = True + desc_act = False + + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + data = json.load(f) + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] + # Order is important here, desc_act is missing on some real models + quant_method = data["quantization_config"]["quant_method"] + checkpoint_format = data["quantization_config"].get("checkpoint_format") + sym = data["quantization_config"]["sym"] + desc_act = data["quantization_config"]["desc_act"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["bits"] + groupsize = data["group_size"] + sym = data["sym"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + filename = "quant_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["w_bit"] + groupsize = data["q_group_size"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + pass + + return _QuantizerConfig( + bits=bits, + groupsize=groupsize, + quant_method=quant_method, + checkpoint_format=checkpoint_format, + sym=sym, + desc_act=desc_act, + ) + + +def get_loader( + quantize: Optional[str], model_id: str, revision: Optional[str] +) -> WeightsLoader: + quantizer_config = _get_quantizer_config(model_id, revision) + if quantize in {"awq", "gptq"}: + from text_generation_server.layers.gptq import GPTQWeightsLoader + + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + elif quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2WeightsLoader + + return Exl2WeightsLoader() + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeightsLoader + + return MarlinWeightsLoader( + bits=quantizer_config.bits, + is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", + ) + else: + return DefaultWeightsLoader() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 3731fd24..1a62fb3b 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,13 +1,88 @@ -import os +from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional, Union -from safetensors import safe_open, SafetensorError +from safetensors import safe_open import torch -from loguru import logger -from huggingface_hub import hf_hub_download -import json -from text_generation_server.layers.gptq import GPTQParams -from text_generation_server.utils.log import log_once + + +class WeightsLoader(ABC): + """ + Instances of this type implement higher-level weight loading. + + At a low-level, every weight is stored in the Safetensors format. + The interpretation of weights may be different however, for instance + could be packed, quantized weights. Loaders are responsible for + interpreting the raw tensors, sharding tensors in a manner compatible + with the format, etc. + """ + + @abstractmethod + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + """ + Get the packed weights at the given prefix with column-splitting for + tensor parallelism. This method should be used when multiple different + weights are packed into a tensor, for instance, query/key/value + weights or a gate/up projection. + + The `block_sizes` determines the proportions of the packed tensors. + The columns are split in equally sized blocks when `block_sizes` is an + `int`, or in blocks proportional given to the sizes. For instance + `[2, 1, 1]` will divide an input with dimensionality `1024` in + `[512, 256, 256]`. + """ + ... + + def get_weights_col(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply column-splitting for tensor + paralllism. + """ + return weights.get_multi_weights_col([prefix], 0) + + @abstractmethod + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + """ + Get the weights at the given prefixes, column-split them for tensor + parallelim, and then concatenate the weights along the given dimension. + """ + ... + + @abstractmethod + def get_weights_row(self, weights: "Weights", prefix: str): + """ + Get the weights at the given prefix and apply row-splitting for tensor + parallism. + """ + ... + + +class DefaultWeightsLoader(WeightsLoader): + """ + Loader that uses tensors as-is with the exception of applying sharding + and/or concatenation. + """ + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + return weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + return torch.cat(w, dim=dim) + + def get_weights_row(self, weights: "Weights", prefix: str): + return weights.get_sharded(f"{prefix}.weight", dim=1) class Weights: @@ -17,6 +92,7 @@ class Weights: device, dtype, process_group, + weights_loader: WeightsLoader, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, ): @@ -37,6 +113,7 @@ class Weights: self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = weights_loader self._handles = {} def _get_handle(self, filename): @@ -181,295 +258,27 @@ class Weights: num_key_value_heads: int, ): return self.get_weights_col_packed( - prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + prefix, [num_heads, num_key_value_heads, num_key_value_heads] ) - def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 2) + def get_weights_col_packed_gate_up(self, prefix: str): + return self.get_weights_col_packed(prefix, 2) - def get_weights_col_packed( - self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] - ): + def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]): """ - Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor. - The columns are split in equally sized blocks when blocks is an `int`, or in blocks proportional given to the sizes. For instance `[2, 1, 1]` will divide an input with dimensionality `1024` in `[512, 256, 256]`. This is convenient for e.g. splitting QKV without knowing the storage details of quantized weights. """ - if quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) + return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes) - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized." - ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) + def get_weights_col(self, prefix: str): + return self.weights_loader.get_weights_col(self, prefix) - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - g_idx = self.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = self.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - if quantize == "gptq" and gptq_params.quant_method == "gptq": - g_idx = self.get_tensor(f"{prefix}.g_idx") - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=False, - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - repack_gptq_for_marlin, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - B = self.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = self.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - B = self.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - return weight - - def get_weights_col(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - return self.get_multi_weights_col([prefix], quantize, 0) - - def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize == "exl2": - raise ValueError("get_multi_weights_col is not supported for exl2") - elif quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - gptq_params.bits == 4 - and HAS_EXLLAMA - and quantize == "gptq" - and not gptq_params.desc_act - ) - - if quantize == "gptq" and gptq_params.quant_method == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - else: - w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=dim) - - return weight + def get_multi_weights_col(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights_col(self, prefixes, dim) def get_tensor_shard(self, var, dim): world_size = self.process_group.size() @@ -487,318 +296,8 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - elif quantize == "gptq": - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) - - use_exllama = True - if gptq_params.bits != 4: - use_exllama = False - - if gptq_params.desc_act: - log_once(logger.warning, "Disabling exllama because desc_act=True") - use_exllama = False - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - - if gptq_params.quant_method == "gptq": - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - elif gptq_params.quant_method == "awq": - g_idx = None - - if self.process_group.size() > 1: - if g_idx is not None: - if ( - not torch.equal( - g_idx.cpu(), - torch.tensor( - [ - i // gptq_params.groupsize - for i in range(g_idx.shape[0]) - ], - dtype=torch.int32, - ), - ) - and not (g_idx == 0).all() - ): - # Exllama implementation does not support row tensor parallelism with act-order, as - # it would require to reorder input activations that are split unto several GPUs - use_exllama = False - - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and gptq_params.groupsize != -1: - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - - if gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `awq` weight, make sure the model is already quantized" - ) - - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - g_idx = None - use_exllama = False - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = self.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = self.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_sharded(f"{prefix}.weight", dim=1) - return weight - - def _get_gptq_params(self) -> GPTQParams: - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = False - sym = False - quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - try: - bits = self.gptq_bits - groupsize = self.gptq_groupsize - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = getattr(self, "gptq_desc_act", False) - quant_method = getattr(self, "quant_method", "gptq") - sym = getattr(self, "sym", True) - except Exception: - raise e - - return GPTQParams( - bits=bits, - checkpoint_format=checkpoint_format, - desc_act=desc_act, - groupsize=groupsize, - quant_method=quant_method, - sym=sym, - ) - - def _set_gptq_params(self, model_id, revision): - filename = "config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["quantization_config"]["bits"] - self.gptq_groupsize = data["quantization_config"]["group_size"] - # Order is important here, desc_act is missing on some real models - self.quant_method = data["quantization_config"]["quant_method"] - self.gptq_checkpoint_format = data["quantization_config"].get( - "checkpoint_format" - ) - self.gptq_sym = data["quantization_config"]["sym"] - self.gptq_desc_act = data["quantization_config"]["desc_act"] - except Exception: - filename = "quantize_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["bits"] - self.gptq_groupsize = data["group_size"] - self.gptq_sym = data["sym"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - filename = "quant_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["w_bit"] - self.gptq_groupsize = data["q_group_size"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - pass + def get_weights_row(self, prefix: str): + return self.weights_loader.get_weights_row(self, prefix) def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: From 85c3c5d64fd01b754f0ab94a5e358b4e1bdcaf6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 11 Jul 2024 16:03:26 +0200 Subject: [PATCH 125/340] Add support for FP8 on compute capability >=8.0, <8.9 (#2213) Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs with compute capability >=8.0 and <8.9. Co-authored-by: Florian Zimmermeister --- server/marlin/marlin_kernels/__init__.pyi | 15 + server/marlin/marlin_kernels/ext.cpp | 2 + server/marlin/marlin_kernels/ext.hh | 5 + server/marlin/marlin_kernels/fp8_marlin.cu | 1308 +++++++++++++++++ server/marlin/setup.py | 1 + server/text_generation_server/layers/fp8.py | 19 + .../text_generation_server/layers/linear.py | 4 +- .../text_generation_server/layers/marlin.py | 115 +- 8 files changed, 1465 insertions(+), 4 deletions(-) create mode 100644 server/marlin/marlin_kernels/fp8_marlin.cu diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi index 663984d0..53464719 100644 --- a/server/marlin/marlin_kernels/__init__.pyi +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -59,3 +59,18 @@ def marlin_gemm( Matrix multiplication using Marlin kernels. """ ... + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return torch.ops._C.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 37eccef6..04e1530f 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gptq_marlin_repack", &gptq_marlin_repack, "Repack GPTQ parameters for Marlin"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); + // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. + m.def("fp8_marlin_gemm", &fp8_marlin_gemm); } diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index d1caaab7..102c058e 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &workspace, int64_t size_m, int64_t size_n, int64_t size_k); +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k); + #endif diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu new file mode 100644 index 00000000..aaef67e5 --- /dev/null +++ b/server/marlin/marlin_kernels/fp8_marlin.cu @@ -0,0 +1,1308 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "./gptq_marlin.cuh" +#include "./gptq_marlin_dtypes.cuh" + +using namespace gptq_marlin; + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace fp8_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + int slice_k_start = tb_k * slice_row; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We scale a `half2` tile in row-major layout for column-wise quantization. + int s_sh_rd = + 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + + thread_block_reduce(); + + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + + start_pipes(); + } + } + } +} + + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ + locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, + int group_size) { + int tb_n = th_config.thread_n; + + // Get max scale groups per thread-block + // Fixed for channelwise + int tb_groups = 1; + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, + prob_k, num_bits, group_size); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) + +template +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { + TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, + group_size, max_shared_mem); + } + + TORCH_CHECK( + exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = -1; + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + int* locks = (int*)workspace; + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + // Channelwise only for FP8 + TORCH_CHECK(b_scales.size(0) == 1) + num_groups = b_scales.size(0); + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), size_m, + size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, + dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/server/marlin/setup.py b/server/marlin/setup.py index aed84e9e..cc38bccf 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -9,6 +9,7 @@ setup( CUDAExtension( name="marlin_kernels", sources=[ + "marlin_kernels/fp8_marlin.cu", "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", "marlin_kernels/marlin_cuda_kernel.cu", diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index dd61d081..b76af8f1 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,4 +1,23 @@ +from enum import Enum, auto + import torch +from text_generation_server.utils.import_utils import SYSTEM + + +def get_fp8_linear() -> torch.nn.Module: + """ + Return an FP8 linear `Module` that is compatible with the current system. + """ + + if SYSTEM == "cuda": + major, minor = torch.cuda.get_device_capability() + if major == 8 and minor < 9: + from text_generation_server.layers.marlin import GPTQMarlinFP8Linear + + return GPTQMarlinFP8Linear + + # On other systems let Torch decide if the hardware supports FP8. + return Fp8Linear def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index e94e5465..babd86b0 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize): "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" ) elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Linear + from text_generation_server.layers.fp8 import get_fp8_linear - linear = Fp8Linear(weight, bias) + linear = get_fp8_linear()(weight, bias) elif quantize == "bitsandbytes": try: from text_generation_server.layers.bnb import ( diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index ecb88e76..9777a47e 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union -from text_generation_server.utils.weights import Weights, WeightsLoader import torch import torch.nn as nn - +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights, WeightsLoader try: import marlin_kernels @@ -455,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module): return C +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + qweight, scale = fp8_quantize(weight) + scale = scale.to(torch.float16) + qweight, scales = repack_fp8_for_marlin(qweight, scale) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + + if fp8_tensor.shape[0] % 4 != 0: + raise ValueError( + f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" + ) + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = torch.zeros( + fp8_tensor.shape[0] // 4, + fp8_tensor.shape[1], + dtype=torch.int32, + device=fp8_tensor.device, + ) + + for i in range(4): + packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) + + return packed + + +def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): + """ + Repack FP8 tensor for GPTQ-Marlin. + """ + + out_features, in_features = weight.shape + + # Torch linear layers weights with shape [out_features, in_features], + # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], + # so transpose before packing. + qweight = pack_fp8_as_int32(weight.t()) + + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, 8 + ) + + scales = scale.reshape(1, 1).repeat(1, out_features) + scales = permute_scales(scales) + + return repacked, scales + + @dataclass class MarlinWeight: """ From 5029e7215c6710c087f3b90fe87726d42b6f6325 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 11 Jul 2024 10:42:58 -0400 Subject: [PATCH 126/340] fix: append DONE message to chat stream (#2221) * fix: append DONE message to chat stream * fix: update completions endpoint --- router/src/server.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/router/src/server.rs b/router/src/server.rs index 4e5af99c..d3a280ca 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -812,6 +812,10 @@ async fn completions( } }; + let stream = stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1171,6 +1175,11 @@ async fn chat_completions( span, ) .await; + + let response_stream = response_stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { From dedeb3cfa04bb54f0fb0d86a0a2753b17b1719b4 Mon Sep 17 00:00:00 2001 From: SeongBeomLEE <2712qwer@gmail.com> Date: Fri, 12 Jul 2024 17:04:51 +0900 Subject: [PATCH 127/340] Modifying base in yarn embedding (#2212) --- server/text_generation_server/layers/rotary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 87a61e82..8c354b82 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module): max_position_embeddings=rope_scaling[ "original_max_position_embeddings" ], - base=10000.0, + base=base, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, From ee562660443511db7f055acefd2f67b0fb11ecf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 12 Jul 2024 12:20:12 +0200 Subject: [PATCH 128/340] Use symmetric quantization in the `quantize` subcommand (#2120) Packing of asymmetric quantization is broken, all (q)zeros values of `0` get reset to `1`, resulting in a loss of accuracy. So instead use symmetric quantization. To be able to distinguish models with symmetric and asymmetric quantization, a new config tensor `gptq_sym` is added. If this tensor is not present, we assume `sym=False`. --- server/text_generation_server/cli.py | 1 + .../text_generation_server/layers/gptq/__init__.py | 12 ++++++++---- .../text_generation_server/layers/gptq/quantize.py | 3 +++ server/text_generation_server/utils/weights.py | 7 +++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd..71ad18f7 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -353,6 +353,7 @@ def quantize( upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index efcb3118..aaa7a68a 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -393,11 +393,15 @@ class GPTQWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - try: + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False - self.sym = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) self.quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - pass diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index c65d5e78..0271d913 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -871,6 +871,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -946,6 +947,7 @@ def quantize( percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -957,6 +959,7 @@ def quantize( state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 1a62fb3b..50a9167a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -146,6 +146,13 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() From 619eeded478af6a813e0b8168a583d2f696d92bb Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 15 Jul 2024 09:16:15 -0400 Subject: [PATCH 129/340] feat: simple mistral lora integration tests (#2180) * feat: simple mistral lora integration tests * fix: include args in docker launcher * fix: disable cuda graphs with lora and warn * fix: adjust docs and precommit issues * fix: re update docs --- integration-tests/conftest.py | 18 ++ ...mistral_with_customer_support_adapter.json | 251 ++++++++++++++++++ ...est_lora_mistral_with_dbpedia_adapter.json | 53 ++++ .../test_lora_mistral_without_adapter.json | 251 ++++++++++++++++++ ...tral_without_customer_support_adapter.json | 251 ++++++++++++++++++ integration-tests/models/test_lora_mistral.py | 134 ++++++++++ server/text_generation_server/cli.py | 9 + 7 files changed, 967 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json create mode 100644 integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json create mode 100644 integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json create mode 100644 integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json create mode 100644 integration-tests/models/test_lora_mistral.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index f5f38ac6..60146ad1 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -333,6 +333,8 @@ def launcher(event_loop): max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, + lora_adapters: Optional[List[str]] = None, + cuda_graphs: Optional[List[int]] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -379,6 +381,14 @@ def launcher(event_loop): if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) + if lora_adapters: + args.append("--lora-adapters") + args.append(",".join(lora_adapters)) + if cuda_graphs: + args.append("--cuda-graphs") + args.append(",".join(map(str, cuda_graphs))) + + print(" ".join(args), file=sys.stderr) env["LOG_LEVEL"] = "info,text_generation_router=debug" @@ -418,6 +428,8 @@ def launcher(event_loop): max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, + lora_adapters: Optional[List[str]] = None, + cuda_graphs: Optional[List[int]] = None, ): port = random.randint(8000, 10_000) @@ -447,6 +459,12 @@ def launcher(event_loop): if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) + if lora_adapters: + args.append("--lora-adapters") + args.append(",".join(lora_adapters)) + if cuda_graphs: + args.append("--cuda-graphs") + args.append(",".join(map(str, cuda_graphs))) client = docker.from_env() diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json new file mode 100644 index 00000000..dfdd2cc3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.27416992, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.17016602, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -2.7109375, + "special": false, + "text": "I" + }, + { + "id": 28809, + "logprob": -1.5, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.34204102, + "special": false, + "text": "m" + }, + { + "id": 459, + "logprob": -1.6914062, + "special": false, + "text": " not" + }, + { + "id": 1864, + "logprob": -0.69140625, + "special": false, + "text": " sure" + }, + { + "id": 513, + "logprob": -1.6171875, + "special": false, + "text": " if" + }, + { + "id": 315, + "logprob": -1.3837891, + "special": false, + "text": " I" + }, + { + "id": 541, + "logprob": -1.2226562, + "special": false, + "text": " can" + }, + { + "id": 1567, + "logprob": -1.8652344, + "special": false, + "text": " come" + }, + { + "id": 582, + "logprob": -0.0070228577, + "special": false, + "text": " up" + }, + { + "id": 395, + "logprob": -0.0054092407, + "special": false, + "text": " with" + }, + { + "id": 28705, + "logprob": -0.62597656, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.0035572052, + "special": false, + "text": "3" + }, + { + "id": 4842, + "logprob": -0.93603516, + "special": false, + "text": " unique" + }, + { + "id": 3085, + "logprob": -0.028411865, + "special": false, + "text": " words" + }, + { + "id": 369, + "logprob": -1.0400391, + "special": false, + "text": " that" + }, + { + "id": 6685, + "logprob": -0.09710693, + "special": false, + "text": " describe" + }, + { + "id": 528, + "logprob": -0.066467285, + "special": false, + "text": " me" + }, + { + "id": 28725, + "logprob": -1.0722656, + "special": false, + "text": "," + }, + { + "id": 562, + "logprob": -0.33422852, + "special": false, + "text": " but" + }, + { + "id": 315, + "logprob": -0.5136719, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -0.8989258, + "special": false, + "text": "’" + }, + { + "id": 584, + "logprob": -0.2076416, + "special": false, + "text": "ll" + }, + { + "id": 1464, + "logprob": -0.8808594, + "special": false, + "text": " try" + }, + { + "id": 28723, + "logprob": -0.88427734, + "special": false, + "text": "." + }, + { + "id": 13, + "logprob": -0.91064453, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.08105469, + "special": false, + "text": "\n" + }, + { + "id": 28740, + "logprob": -1.8486328, + "special": false, + "text": "1" + }, + { + "id": 28723, + "logprob": -0.111572266, + "special": false, + "text": "." + }, + { + "id": 23626, + "logprob": -3.15625, + "special": false, + "text": " Creative" + }, + { + "id": 13, + "logprob": -0.9194336, + "special": false, + "text": "\n" + }, + { + "id": 28750, + "logprob": -0.24841309, + "special": false, + "text": "2" + }, + { + "id": 28723, + "logprob": -9.393692e-05, + "special": false, + "text": "." + }, + { + "id": 6785, + "logprob": -3.1386719, + "special": false, + "text": " Fun" + }, + { + "id": 1780, + "logprob": -0.53564453, + "special": false, + "text": "ny" + }, + { + "id": 13, + "logprob": -0.09033203, + "special": false, + "text": "\n" + }, + { + "id": 28770, + "logprob": -0.00466156, + "special": false, + "text": "3" + }, + { + "id": 28723, + "logprob": -0.00016450882, + "special": false, + "text": "." + } + ] + }, + "generated_text": "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3." +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json new file mode 100644 index 00000000..91eb5edf --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json @@ -0,0 +1,53 @@ +{ + "details": { + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 1, + "logprob": -0.49658203, + "special": true, + "text": "" + }, + { + "id": 28705, + "logprob": -0.0016384125, + "special": false, + "text": " " + }, + { + "id": 1, + "logprob": -1.4931641, + "special": true, + "text": "" + }, + { + "id": 28705, + "logprob": -0.00075769424, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -0.25024414, + "special": false, + "text": "1" + }, + { + "id": 28740, + "logprob": -0.2631836, + "special": false, + "text": "1" + }, + { + "id": 2, + "logprob": -0.0003285408, + "special": true, + "text": "" + } + ] + }, + "generated_text": " 11" +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json new file mode 100644 index 00000000..13018688 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.0488281, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.0800781, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -2.1152344, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -1.6748047, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -0.097229004, + "special": false, + "text": "1" + }, + { + "id": 28723, + "logprob": -0.16467285, + "special": false, + "text": "." + }, + { + "id": 7615, + "logprob": -2.2246094, + "special": false, + "text": " News" + }, + { + "id": 13, + "logprob": -1.0488281, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.69189453, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.013343811, + "special": false, + "text": " " + }, + { + "id": 28750, + "logprob": -0.011230469, + "special": false, + "text": "2" + }, + { + "id": 28723, + "logprob": -0.00096845627, + "special": false, + "text": "." + }, + { + "id": 21095, + "logprob": -2.5605469, + "special": false, + "text": " Blog" + }, + { + "id": 13, + "logprob": -0.19458008, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.031280518, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.0030708313, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.0029277802, + "special": false, + "text": "3" + }, + { + "id": 28723, + "logprob": -0.0012350082, + "special": false, + "text": "." + }, + { + "id": 20108, + "logprob": -2.1582031, + "special": false, + "text": " Article" + }, + { + "id": 13, + "logprob": -0.05810547, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.35083008, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.034332275, + "special": false, + "text": " " + }, + { + "id": 28781, + "logprob": -0.009666443, + "special": false, + "text": "4" + }, + { + "id": 28723, + "logprob": -0.0013113022, + "special": false, + "text": "." + }, + { + "id": 8349, + "logprob": -2.6191406, + "special": false, + "text": " Review" + }, + { + "id": 13, + "logprob": -0.04031372, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.45239258, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.045410156, + "special": false, + "text": " " + }, + { + "id": 28782, + "logprob": -0.0041236877, + "special": false, + "text": "5" + }, + { + "id": 28723, + "logprob": -0.0010223389, + "special": false, + "text": "." + }, + { + "id": 5299, + "logprob": -2.8066406, + "special": false, + "text": " Other" + }, + { + "id": 13, + "logprob": -0.12054443, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.44580078, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.4921875, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.3574219, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.0039062, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.5859375, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.43481445, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.2783203, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20410156, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json new file mode 100644 index 00000000..8c00dee7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.31347656, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.27441406, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -2.2285156, + "special": false, + "text": "I" + }, + { + "id": 28809, + "logprob": -1.4677734, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.31762695, + "special": false, + "text": "m" + }, + { + "id": 264, + "logprob": -1.6865234, + "special": false, + "text": " a" + }, + { + "id": 1215, + "logprob": -3.2695312, + "special": false, + "text": " very" + }, + { + "id": 20640, + "logprob": -3.1230469, + "special": false, + "text": " passionate" + }, + { + "id": 1338, + "logprob": -0.48339844, + "special": false, + "text": " person" + }, + { + "id": 28723, + "logprob": -0.9970703, + "special": false, + "text": "." + }, + { + "id": 315, + "logprob": -0.5498047, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -1.1923828, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.080444336, + "special": false, + "text": "m" + }, + { + "id": 1215, + "logprob": -1.8271484, + "special": false, + "text": " very" + }, + { + "id": 12215, + "logprob": -2.8847656, + "special": false, + "text": " driven" + }, + { + "id": 28723, + "logprob": -1.0927734, + "special": false, + "text": "." + }, + { + "id": 315, + "logprob": -0.4584961, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -0.5019531, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.030715942, + "special": false, + "text": "m" + }, + { + "id": 1215, + "logprob": -0.96972656, + "special": false, + "text": " very" + }, + { + "id": 7798, + "logprob": -2.8847656, + "special": false, + "text": " determined" + }, + { + "id": 28723, + "logprob": -0.27319336, + "special": false, + "text": "." + }, + { + "id": 13, + "logprob": -0.56396484, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.011016846, + "special": false, + "text": "\n" + }, + { + "id": 3195, + "logprob": -0.7163086, + "special": false, + "text": "What" + }, + { + "id": 349, + "logprob": -1.1611328, + "special": false, + "text": " is" + }, + { + "id": 574, + "logprob": -0.515625, + "special": false, + "text": " your" + }, + { + "id": 6656, + "logprob": -1.0253906, + "special": false, + "text": " favorite" + }, + { + "id": 1970, + "logprob": -2.1738281, + "special": false, + "text": " thing" + }, + { + "id": 684, + "logprob": -0.48364258, + "special": false, + "text": " about" + }, + { + "id": 1250, + "logprob": -1.8876953, + "special": false, + "text": " being" + }, + { + "id": 264, + "logprob": -0.41967773, + "special": false, + "text": " a" + }, + { + "id": 8626, + "logprob": -2.9160156, + "special": false, + "text": " teacher" + }, + { + "id": 28804, + "logprob": -0.11920166, + "special": false, + "text": "?" + }, + { + "id": 13, + "logprob": -0.023727417, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.010848999, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -1.0566406, + "special": false, + "text": "I" + }, + { + "id": 2016, + "logprob": -0.7163086, + "special": false, + "text": " love" + }, + { + "id": 272, + "logprob": -1.9169922, + "special": false, + "text": " the" + }, + { + "id": 1639, + "logprob": -2.03125, + "special": false, + "text": " fact" + } + ] + }, + "generated_text": "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact" +} diff --git a/integration-tests/models/test_lora_mistral.py b/integration-tests/models/test_lora_mistral.py new file mode 100644 index 00000000..ccdc1486 --- /dev/null +++ b/integration-tests/models/test_lora_mistral.py @@ -0,0 +1,134 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def lora_mistral_handle(launcher): + with launcher( + "mistralai/Mistral-7B-v0.1", + lora_adapters=[ + "predibase/dbpedia", + "predibase/customer_support", + ], + cuda_graphs=[0], + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def lora_mistral(lora_mistral_handle): + await lora_mistral_handle.health(300) + return lora_mistral_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral(lora_mistral, response_snapshot): + response = await lora_mistral.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + assert response.details.generated_tokens == 10 + + +classification_prompt = """You are given the title and the body of an article below. Please determine the type of the article.\n### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\n\n### Article Type:""" + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_without_adapter(lora_mistral, response_snapshot): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": classification_prompt, + "parameters": { + "max_new_tokens": 40, + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n" + ) + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": classification_prompt, + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/dbpedia", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generated_text"] == " 11" + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_with_customer_support_adapter( + lora_mistral, response_snapshot +): + print(lora_mistral.base_url) + print(lora_mistral.headers) + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": "What are 3 unique words that describe you?", + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/customer_support", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3." + ) + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_without_customer_support_adapter( + lora_mistral, response_snapshot +): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": "What are 3 unique words that describe you?", + "parameters": { + "max_new_tokens": 40, + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact" + ) + assert data == response_snapshot diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 71ad18f7..eab64831 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -91,6 +91,15 @@ def serve( f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." ) + # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled + # and warn the user + if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: + logger.warning( + f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs." + ) + global CUDA_GRAPHS + CUDA_GRAPHS = None + # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value From 271ebb7e20c09b83bd250b17d28414172433da3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Mon, 15 Jul 2024 15:17:13 +0200 Subject: [PATCH 130/340] fix custom cache dir (#2226) * fix to not ignore HUGGINGFACE_HUB_CACHE in cache * delete printlns * delete newlines * maybe fix trailing whitespace --- router/src/main.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/router/src/main.rs b/router/src/main.rs index 21cd6649..b060d73c 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -210,7 +210,11 @@ async fn main() -> Result<(), RouterError> { } let api = if use_api { if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { - let cache = Cache::default(); + let cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active using cache defaults"); Type::Cache(cache) } else { From 8a223eb6ac9a8e9df36bcc900db4c1f1b2ae8961 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Mon, 15 Jul 2024 15:34:20 +0200 Subject: [PATCH 131/340] fix: Remove bitsandbytes installation when running cpu-only install (#2216) Remove bitsandbytes installation when running cpu-only install --- server/Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/Makefile b/server/Makefile index d701c819..33940655 100644 --- a/server/Makefile +++ b/server/Makefile @@ -21,13 +21,14 @@ gen-server: install-server: gen-server pip install pip --upgrade pip install -r requirements_cuda.txt - pip install -e ".[bnb, accelerate, quantize, peft, outlines]" + pip install -e ".[accelerate, quantize, peft, outlines]" install: install-cuda echo "Installed server" install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention + pip install -e ".[bnb]" install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm From e955f7b536f88e6628732e285b7fc3a4fdca13aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 16 Jul 2024 07:58:25 +0200 Subject: [PATCH 132/340] Add support for AWQ-quantized Idefics2 (#2233) Fixes #2036. --- .../models/custom_modeling/idefics2.py | 35 +++++++++++++------ .../text_generation_server/utils/weights.py | 15 ++++++++ 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index a83bc1c6..daf3329a 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -34,6 +34,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, ) +from text_generation_server.utils.weights import DefaultWeightsLoader def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module): class Idefics2ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - config.vision_config.quantize = config.quantize + config.vision_config.quantize = None config.vision_config.speculator = config.speculator config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator @@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module): name="text_model", ) self.dtype = weights.dtype - self.vision_model = Idefics2VisionTransformer( - prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model", - config=vision_config, - weights=weights, - ) - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader()): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + quantize = config.quantize + try: + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + finally: + config.quantize = quantize + self.config = config self.image_seq_len = config.perceiver_config.resampler_n_latents self.image_token_id = config.image_token_id diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 50a9167a..1c55dd74 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from contextlib import contextmanager from pathlib import Path from typing import Dict, List, Optional, Union from safetensors import safe_open @@ -306,6 +307,20 @@ class Weights: def get_weights_row(self, prefix: str): return self.weights_loader.get_weights_row(self, prefix) + @contextmanager + def use_loader(self, weights_loader: WeightsLoader): + """ + This method is a context manager that can be used to use `Weights` with + a different loader for the duration of the context. + """ + + old_loader = self.weights_loader + self.weights_loader = weights_loader + try: + yield + finally: + self.weights_loader = old_loader + def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: """ From 7177da0df6b910018c91c3d4f1baf819b8c42e0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 16 Jul 2024 08:36:05 +0200 Subject: [PATCH 133/340] `server quantize`: expose groupsize option (#2225) --- server/text_generation_server/cli.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index eab64831..fe839cf4 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -341,6 +341,7 @@ def quantize( upload_to_model_id: Optional[str] = None, percdamp: float = 0.01, act_order: bool = False, + groupsize: int = 128, ): if revision is None: revision = "main" @@ -355,7 +356,7 @@ def quantize( quantize( model_id=model_id, bits=4, - groupsize=128, + groupsize=groupsize, output_dir=output_dir, revision=revision, trust_remote_code=trust_remote_code, From e0710ccbeb71fd761c700a498ef6059fbef561f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 16 Jul 2024 09:30:57 +0200 Subject: [PATCH 134/340] Remove stray `quantize` argument in `get_weights_col_packed_qkv` (#2237) Fixes #2236. --- server/text_generation_server/utils/weights.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 1c55dd74..b530af23 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -261,7 +261,6 @@ class Weights: def get_weights_col_packed_qkv( self, prefix: str, - quantize: str, num_heads: int, num_key_value_heads: int, ): From 118ee57f82e67fd3da445bd49e4e6782eefd1505 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 18 Jul 2024 14:00:13 +0000 Subject: [PATCH 135/340] fix(server): fix cohere (#2249) --- .../models/custom_modeling/flash_cohere_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 25719b99..49f5a81b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -259,8 +259,8 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, - input_lengths, slots, + input_lengths, max_s, ): qkv = self.query_key_value(hidden_states) From 2dd680b799bc321854e19ca475a83d8a8022fd66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 09:37:39 +0200 Subject: [PATCH 136/340] Improve the handling of quantized weights (#2250) * Improve the handling of quantized weights Handling of quantized weights was split between two mechanisms: - For quantized checkpoints, we used the new weight loader infrastructure. - For quantization while loading (EETQ, FP8, bitsandbytes) we instead relied on conditional in `get_linear`. Weight loaders support context managers to selectively load particular layers with different weight loaders, which is useful for models like Idefics2 AWQ, which uses a quantized text model, but unquantized vision and connector models. However, the context manager would be overrided by `get_linear`, which string-checks `quantizer`. Also, the context manager would not work with EETQ, FP8, and bitsandbytes. This change migrates all quantizers to the weight loader infrastructure. This has several benefits: - We can use context managers with all quantizers. - All the implementation details move down to the quantizer layers, `get_linear` does not need to know how to handle quantizer linear layers. - All quantizer weights are strongly typed, we don't pass around raw tensors. - We don't have to pass around the `quantizer` string everywhere. * Exclude non-MLP layers when using FP8 quantization with Llama --- .../test_flash_llama_marlin.json | 89 +++++ .../test_flash_llama_marlin24_all_params.json | 89 +++++ .../test_flash_llama_marlin24_load.json | 358 ++++++++++++++++++ .../models/test_completion_prompts.py | 2 + .../models/test_flash_llama_marlin_24.py | 66 ++++ server/tests/utils/test_weights.py | 22 +- server/text_generation_server/layers/bnb.py | 31 +- server/text_generation_server/layers/eetq.py | 18 + server/text_generation_server/layers/exl2.py | 13 +- server/text_generation_server/layers/fp8.py | 11 +- .../layers/gptq/__init__.py | 68 +++- .../text_generation_server/layers/linear.py | 172 +-------- .../text_generation_server/layers/marlin.py | 27 +- .../layers/tensor_parallel.py | 14 +- .../custom_modeling/flash_cohere_modeling.py | 4 +- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../custom_modeling/flash_gemma2_modeling.py | 4 +- .../custom_modeling/flash_gemma_modeling.py | 4 +- .../custom_modeling/flash_gpt2_modeling.py | 8 +- .../custom_modeling/flash_llama_modeling.py | 62 ++- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../custom_modeling/flash_neox_modeling.py | 4 +- .../custom_modeling/flash_phi_modeling.py | 4 +- .../custom_modeling/flash_rw_modeling.py | 2 +- .../flash_santacoder_modeling.py | 9 +- .../flash_starcoder2_modeling.py | 4 +- .../models/custom_modeling/idefics2.py | 20 +- .../models/custom_modeling/mpt_modeling.py | 2 +- .../utils/quantization.py | 35 +- .../text_generation_server/utils/weights.py | 51 ++- 30 files changed, 936 insertions(+), 265 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json create mode 100644 integration-tests/models/test_flash_llama_marlin_24.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json new file mode 100644 index 00000000..94883de5 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json new file mode 100644 index 00000000..58cacb80 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -0.6645508, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 6527, + "logprob": -2.2324219, + "special": false, + "text": " Could" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " not" + }, + { + "id": 6088, + "logprob": -1.6074219, + "special": false, + "text": " parse" + }, + { + "id": 1243, + "logprob": -1.6298828, + "special": false, + "text": " test" + }, + { + "id": 1206, + "logprob": -0.72558594, + "special": false, + "text": " case" + }, + { + "id": 1024, + "logprob": -0.40429688, + "special": false, + "text": " name" + }, + { + "id": 515, + "logprob": 0.0, + "special": false, + "text": " from" + }, + { + "id": 525, + "logprob": -1.2519531, + "special": false, + "text": " '" + } + ], + "top_tokens": null + }, + "generated_text": "Test request failed: Could not parse test case name from '" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json new file mode 100644 index 00000000..96a40fa4 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + } +] diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index 0efb6693..d787873b 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream( chunk = [c.replace("data:", "") for c in chunk] # remove empty strings chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] # parse json chunk = [json.loads(c) for c in chunk] diff --git a/integration-tests/models/test_flash_llama_marlin_24.py b/integration-tests/models/test_flash_llama_marlin_24.py new file mode 100644 index 00000000..3eb94f02 --- /dev/null +++ b/integration-tests/models/test_flash_llama_marlin_24.py @@ -0,0 +1,66 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_marlin24_handle(launcher): + with launcher( + "nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_marlin(flash_llama_marlin24_handle): + await flash_llama_marlin24_handle.health(300) + return flash_llama_marlin24_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_load( + flash_llama_marlin, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_marlin, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 36b27be8..d2d2b76e 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -2,6 +2,7 @@ import pytest import torch from text_generation_server.utils.weights import ( DefaultWeightsLoader, + UnquantizedWeight, Weights, WeightsLoader, ) @@ -363,7 +364,10 @@ class MockWeights(Weights): self.process_group = process_group self.prefix = prefix self.weights_loader = ( - DefaultWeightsLoader() if weights_loader is None else weights_loader + # We don't need to get linear layers, so just wrap raw tensors. + DefaultWeightsLoader(lambda x: x) + if weights_loader is None + else weights_loader ) self._handles = {} @@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index ca39919c..925b0b2d 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -1,8 +1,11 @@ -import torch -from loguru import logger +from dataclasses import dataclass from functools import lru_cache + import bitsandbytes as bnb +import torch from bitsandbytes.nn import Int8Params, Params4bit +from loguru import logger +from text_generation_server.utils.weights import Weight @lru_cache(1) @@ -12,6 +15,14 @@ def warn_deprecate_bnb(): ) +@dataclass +class BNBWeight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0) + + class Linear8bitLt(torch.nn.Module): def __init__( self, @@ -70,6 +81,22 @@ class Linear8bitLt(torch.nn.Module): return out +@dataclass +class BNBFP4Weight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="fp4") + + +@dataclass +class BNBNF4Weight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="nf4") + + class Linear4bit(torch.nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index fd22b5c6..f003f914 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -1,5 +1,23 @@ +from dataclasses import dataclass + import torch from EETQ import quant_weights, w8_a16_gemm +from text_generation_server.utils.weights import Weight + + +@dataclass +class EETQWeight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + try: + from text_generation_server.layers.eetq import EETQLinear + + return EETQLinear(self.weight, bias) + except ImportError: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) class EETQLinear(torch.nn.Module): diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index 55cba1cc..1a5acfa8 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,12 +1,12 @@ -import torch -from typing import List, Union from dataclasses import dataclass +from typing import List, Union -from text_generation_server.utils.weights import WeightsLoader, Weights +import torch +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class Exl2Weight: +class Exl2Weight(Weight): """ Exllama2 exl2 quantized weights. """ @@ -25,6 +25,11 @@ class Exl2Weight: def device(self) -> torch.device: return self.q_weight.device + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.gptq import ExllamaQuantLinear + + return ExllamaQuantLinear(self, bias) + class Exl2WeightsLoader(WeightsLoader): """Loader for exl2-quantized weights.""" diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index b76af8f1..b56f568a 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,7 +1,8 @@ -from enum import Enum, auto +from dataclasses import dataclass import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weight def get_fp8_linear() -> torch.nn.Module: @@ -37,6 +38,14 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +@dataclass +class Fp8Weight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return get_fp8_linear()(self.weight, bias) + + class Fp8Linear(torch.nn.Module): def __init__( self, diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index aaa7a68a..c98dbefc 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,24 +1,23 @@ -from dataclasses import dataclass -from loguru import logger import os +from dataclasses import dataclass from typing import List, Optional, Union -from safetensors import SafetensorError -from text_generation_server.utils.weights import Weights, WeightsLoader + import torch -from text_generation_server.utils.import_utils import ( - SYSTEM, -) +from loguru import logger +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class GPTQWeight: +class GPTQWeight(Weight): qweight: torch.Tensor qzeros: torch.Tensor scales: torch.Tensor g_idx: Optional[torch.Tensor] bits: int groupsize: int + use_awq_kernel: bool use_exllama: bool def __post_init__(self): @@ -29,6 +28,50 @@ class GPTQWeight: def device(self) -> torch.device: return self.qweight.device + def get_linear(self, bias: torch.Tensor): + if self.use_awq_kernel: + if SYSTEM == "rocm": + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." + ) + try: + from text_generation_server.layers.awq.quantize.qmodule import WQLinear + + return WQLinear( + w_bit=self.bits, + group_size=self.groupsize, + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bias=bias, + ) + except ImportError: + raise NotImplementedError( + "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" + ) + elif self.use_exllama: + try: + from text_generation_server.layers.gptq import ExllamaQuantLinear + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + return ExllamaQuantLinear(self, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + return QuantLinear( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + bias, + self.bits, + self.groupsize, + ) + try: major, _minor = torch.cuda.get_device_capability() @@ -45,6 +88,8 @@ elif CAN_EXLLAMA: if V2: from text_generation_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllamav2 import ( create_exllama_buffers, set_device, ) @@ -53,6 +98,8 @@ elif CAN_EXLLAMA: else: from text_generation_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllama import ( create_exllama_buffers, set_device, ) @@ -162,6 +209,7 @@ class GPTQWeightsLoader(WeightsLoader): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=False, ) @@ -255,6 +303,7 @@ class GPTQWeightsLoader(WeightsLoader): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=use_exllama, ) @@ -336,8 +385,8 @@ class GPTQWeightsLoader(WeightsLoader): use_exllama = False from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, CAN_EXLLAMA, + HAS_EXLLAMA, GPTQWeight, ) @@ -389,6 +438,7 @@ class GPTQWeightsLoader(WeightsLoader): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=use_exllama, ) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index babd86b0..a97cc43a 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,7 +1,8 @@ from typing import Optional + import torch -from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM +from torch.nn import functional as F if SYSTEM == "rocm": try: @@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module): return F.linear(inp, self.weight, self.bias) -def get_linear(weight, bias, quantize): - if quantize is None: +def get_linear(weight, bias): + # Weights that are loaded through methods that are not + # quantization-aware are still bare tensors. We may want + # to change this in the future. + if isinstance(weight, torch.Tensor): if SYSTEM == "rocm": - linear = FastLinearROCm(weight, bias) + return FastLinearROCm(weight, bias) else: - linear = FastLinear(weight, bias) - elif quantize == "eetq": - try: - from text_generation_server.layers.eetq import EETQLinear + return FastLinear(weight, bias) - linear = EETQLinear(weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - elif quantize == "fp8": - from text_generation_server.layers.fp8 import get_fp8_linear - - linear = get_fp8_linear()(weight, bias) - elif quantize == "bitsandbytes": - try: - from text_generation_server.layers.bnb import ( - warn_deprecate_bnb, - Linear8bitLt, - ) - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - warn_deprecate_bnb() - linear = Linear8bitLt( - weight, - bias, - has_fp16_weights=False, - threshold=6.0, - ) - if bias is not None: - linear.bias = nn.Parameter(bias) - elif quantize == "bitsandbytes-fp4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="fp4", - ) - elif quantize == "bitsandbytes-nf4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="nf4", - ) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - if not isinstance(weight, Exl2Weight): - raise NotImplementedError( - f"The passed weight is not `exl2` compatible, loader needs to be updated." - ) - - from text_generation_server.layers.gptq import ExllamaQuantLinear - - linear = ExllamaQuantLinear(weight, bias) - - elif quantize == "gptq": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlinLinear, - GPTQMarlinWeight, - ) - - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, GPTQWeight): - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - linear = ExllamaQuantLinear(weight, bias) - else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, - ) - else: - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." - ) - - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - if not isinstance(weight, GPTQWeight): - raise NotImplementedError( - f"The passed weight is not `awq` compatible, loader needs to be updated." - ) - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) - try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear - - linear = WQLinear( - w_bit=weight.bits, - group_size=weight.groupsize, - qweight=weight.qweight, - qzeros=weight.qzeros, - scales=weight.scales, - bias=bias, - ) - except ImportError: - raise NotImplementedError( - "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Linear, - GPTQMarlin24Weight, - MarlinLinear, - MarlinWeight, - ) - - if isinstance(weight, GPTQMarlin24Weight): - linear = GPTQMarlin24Linear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, MarlinWeight): - linear = MarlinLinear(weight=weight, bias=bias) - else: - raise NotImplementedError( - f"The passed weight is not `marlin` compatible, loader needs to be updated." - ) - else: - raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") - return linear + return weight.get_linear(bias) diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 9777a47e..e7f017a4 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -7,7 +7,7 @@ from loguru import logger from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weights, WeightsLoader +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader try: import marlin_kernels @@ -63,8 +63,7 @@ class MarlinWeightsLoader(WeightsLoader): return weight def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: + if self.is_marlin_24: try: B = torch.cat( [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 @@ -101,8 +100,7 @@ class MarlinWeightsLoader(WeightsLoader): return weight def get_weights_row(self, weights: Weights, prefix: str): - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: + if self.is_marlin_24: try: B = weights.get_sharded(f"{prefix}.B_24", dim=0) except RuntimeError: @@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor): @dataclass -class GPTQMarlinWeight: +class GPTQMarlinWeight(Weight): """ Repacked GPTQ Marlin weights. """ @@ -219,6 +217,12 @@ class GPTQMarlinWeight: assert self.g_idx.dtype == torch.int32 assert self.perm.dtype == torch.int32 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlinLinear( + weight=self, + bias=bias, + ) + def repack_gptq_for_marlin( *, @@ -376,6 +380,12 @@ class GPTQMarlin24Weight: assert self.B_meta.dtype == torch.int16 assert self.s.dtype == torch.float16 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlin24Linear( + weight=self, + bias=bias, + ) + class GPTQMarlin24Linear(nn.Module): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): @@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): @dataclass -class MarlinWeight: +class MarlinWeight(Weight): """ Marlin weights. @@ -583,6 +593,9 @@ class MarlinWeight: assert self.B.dtype == torch.int32 assert self.s.dtype == torch.float16 + def get_linear(self, bias: torch.Tensor): + return MarlinLinear(weight=self, bias=bias) + class MarlinLinear(nn.Module): def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 011f105b..9dddb8ae 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer): quantize = config.quantize return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), + get_linear(weight, bias=None), process_group=weights.process_group, should_gather=should_gather, ) @@ -134,7 +134,7 @@ class TensorParallelColumnLinear(SuperLayer): raise NotImplementedError("packed_gate_up only implemented without bias") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -157,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer): raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -167,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -177,7 +177,7 @@ class TensorParallelColumnLinear(SuperLayer): for prefix in prefixes: weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None - linears.append(get_linear(weight, b, config.quantize)) + linears.append(get_linear(weight, b)) linear = LayerConcat(linears) else: weight = weights.get_multi_weights_col(prefixes, dim=dim) @@ -186,7 +186,7 @@ class TensorParallelColumnLinear(SuperLayer): bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @@ -205,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer): else: bias = None return cls( - get_linear(weight, bias, config.quantize), + get_linear(weight, bias), process_group=weights.process_group, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 49f5a81b..c7b29d13 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class FlashCohereAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 44411687..7426fc55 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls): if cls == TensorParallelRowLinear: expert_slice = expert_slice.t().contiguous() - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear, weights.process_group)) else: - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear)) return experts diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a3ce5521..5273b15d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemma2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 34a7efa2..829ad427 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemmaAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index cbfcb1b8..a55a4af3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights): bias = torch.cat(tensors, dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def _load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads): 3 * num_heads * head_size ], f"{weight.shape} != {[3 * num_heads * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): @@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool): bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) @@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool): else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) class FlashGPT2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 78832341..5237a484 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager from typing import List, Optional, Tuple import torch @@ -25,7 +26,6 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -42,10 +42,16 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.fp8 import Fp8Weight from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, +) if SYSTEM == "rocm": try: @@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id): ) +@contextmanager +def no_fp8(weights: Weights): + weights_loader = weights.weights_loader + if ( + isinstance(weights_loader, DefaultWeightsLoader) + and weights_loader.weight_class is Fp8Weight + ): + weights_loader = DefaultWeightsLoader(UnquantizedWeight) + + with weights.use_loader(weights_loader): + yield + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -330,12 +349,15 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): def __init__(self, index, prefix, config, weights): super().__init__() - self.self_attn = FlashLlamaAttention( - index=index, - prefix=f"{prefix}.self_attn", - config=config, - weights=weights, - ) + + with no_fp8(weights): + self.self_attn = FlashLlamaAttention( + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) @@ -470,23 +492,27 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding( - prefix=( - "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" - ), - weights=weights, - ) + with no_fp8(weights): + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" + if not prefix + else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) self.model = FlashLlamaModel(prefix, config, weights) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" - self.lm_head = SpeculativeHead.load( - config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, - ) + with no_fp8(weights): + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 49c0e903..a1e36fc7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) def _load_experts(config, prefix: str, mat, weights): diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 85dcb2a6..99664230 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -56,7 +56,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: @@ -81,7 +81,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 6c508264..a1ce03b9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights): ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" # this is the same as llama except for Phi uses bias=True - return TensorParallelColumnLinear( - get_linear(weight, bias=True, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=True)) class FlashPhiAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 65b40fed..d7cad480 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.parallel_attn: return linear else: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 77b9d49c..2b939a10 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -105,6 +105,7 @@ def _load_multi_mqa_gptq( g_idx=g_idx, bits=loader.bits, groupsize=loader.groupsize, + use_awq_kernel=loader.quantize == "awq", use_exllama=HAS_EXLLAMA, ) @@ -121,7 +122,7 @@ def _load_multi_mqa_gptq( bias = torch.cat([q_tensor, kv_tensor], dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) else: raise NotImplementedError("Gptq loading with santacoder is not implemented") @@ -193,7 +194,7 @@ def _load_multi_mqa( assert list(bias.shape) == [ (num_heads + 2) * head_size ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_col(config, prefix: str, weights, bias: bool): @@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): @@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 19556f78..89471955 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class Starcoder2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index daf3329a..735c3899 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -34,7 +34,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, ) -from text_generation_server.utils.weights import DefaultWeightsLoader +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -698,7 +698,7 @@ class Idefics2ForConditionalGeneration(nn.Module): self.dtype = weights.dtype # The vision and connector models are not quantized. - with weights.use_loader(DefaultWeightsLoader()): + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): self.vision_model = Idefics2VisionTransformer( prefix=( f"{prefix}.model.vision_model" if prefix else "model.vision_model" @@ -707,16 +707,12 @@ class Idefics2ForConditionalGeneration(nn.Module): weights=weights, ) - quantize = config.quantize - try: - config.quantize = None - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) - finally: - config.quantize = quantize + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) self.config = config self.image_seq_len = config.perceiver_config.resampler_n_latents diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index fb09a8f1..bbe9f45a 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias): bias = bias.to(device=weights.device) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return TensorParallelColumnLinear(linear) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 07975bea..e8e22db8 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -1,11 +1,14 @@ -from typing import Optional -import os import json +import os from dataclasses import dataclass +from typing import Optional from huggingface_hub import hf_hub_download - -from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + WeightsLoader, +) @dataclass @@ -104,10 +107,30 @@ def get_loader( quantize=quantize, sym=quantizer_config.sym, ) + elif quantize == "bitsandbytes": + from text_generation_server.layers.bnb import BNBWeight + + return DefaultWeightsLoader(BNBWeight) + elif quantize == "bitsandbytes-fp4": + from text_generation_server.layers.bnb import BNBFP4Weight + + return DefaultWeightsLoader(BNBFP4Weight) + elif quantize == "bitsandbytes-nf4": + from text_generation_server.layers.bnb import BNBNF4Weight + + return DefaultWeightsLoader(BNBNF4Weight) + elif quantize == "eetq": + from text_generation_server.layers.eetq import EETQWeight + + return DefaultWeightsLoader(EETQWeight) elif quantize == "exl2": from text_generation_server.layers.exl2 import Exl2WeightsLoader return Exl2WeightsLoader() + elif quantize == "fp8": + from text_generation_server.layers.fp8 import Fp8Weight + + return DefaultWeightsLoader(Fp8Weight) elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader @@ -115,5 +138,7 @@ def get_loader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) + elif quantize is None: + return DefaultWeightsLoader(UnquantizedWeight) else: - return DefaultWeightsLoader() + raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index b530af23..6876b700 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,9 +1,13 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum, auto from pathlib import Path from typing import Dict, List, Optional, Union -from safetensors import safe_open + import torch +from safetensors import safe_open +from text_generation_server.utils.import_utils import SYSTEM class WeightsLoader(ABC): @@ -62,7 +66,39 @@ class WeightsLoader(ABC): ... +class Weight(ABC): + """Instances of this type implement unquantized/quantized/to-be + quantized weights.""" + + @abstractmethod + def get_linear(self, bias: torch.Tensor): + """Create a linear layer from this weight.""" + ... + + +@dataclass +class UnquantizedWeight: + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.linear import FastLinear, FastLinearROCm + + if SYSTEM == "rocm": + return FastLinearROCm(self.weight, bias) + else: + return FastLinear(self.weight, bias) + + class DefaultWeightsLoader(WeightsLoader): + """Weight loader that loads (unquantized) Torch tensors.""" + + def __init__(self, weight_class): + """Create a loader. Weights will be wrapped using the given `weights_class`, + normally this will be `UnquantizedWeight`, but a quantizer-specific class + such as `Fp8Weight` can be used to quantize the weights during loading. + """ + self.weight_class = weight_class + """ Loader that uses tensors as-is with the exception of applying sharding and/or concatenation. @@ -74,16 +110,21 @@ class DefaultWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): - return weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes + + return self.weight_class( + weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ), ) def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - return torch.cat(w, dim=dim) + return self.weight_class(torch.cat(w, dim=dim)) def get_weights_row(self, weights: "Weights", prefix: str): - return weights.get_sharded(f"{prefix}.weight", dim=1) + return self.weight_class( + weights.get_sharded(f"{prefix}.weight", dim=1), + ) class Weights: From 394f8c7d2be25f84e622c04061f5421fbb34da65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 12:55:59 +0200 Subject: [PATCH 137/340] Hotfix: fix of use of unquantized weights in Gemma GQA loading (#2255) --- .../models/custom_modeling/flash_gemma2_modeling.py | 9 +++++---- .../models/custom_modeling/flash_gemma_modeling.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 5273b15d..8526d515 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class Gemma2Config(PretrainedConfig): @@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" return TensorParallelColumnLinear(get_linear(weight, bias=None)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 829ad427..dfe6510c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class GemmaConfig(PretrainedConfig): @@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" return TensorParallelColumnLinear(get_linear(weight, bias=None)) From ba0dfb6fb1c2b3d9d69fee69b9c12fd99378f079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 14:42:19 +0200 Subject: [PATCH 138/340] Hotfix: various GPT-based model fixes (#2256) --- server/text_generation_server/models/__init__.py | 5 +++++ .../models/custom_modeling/flash_neox_modeling.py | 15 +++++++++++---- .../custom_modeling/flash_starcoder2_modeling.py | 9 +++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ba980195..32156a06 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -573,6 +573,10 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + return FlashCausalLM( model_id=model_id, model_class=FlashGPTNeoXForCausalLM, @@ -582,6 +586,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, ) elif sharded: return CausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 99664230..623b164c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -24,7 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel -from transformers.models.gpt_neox import GPTNeoXConfig +from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( @@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight + + +class GPTNeoXConfig(TransformersGPTNeoXConfig): + attribute_map = { + "num_key_value_heads": "num_attention_heads", + } def load_row(config, prefix: str, weights, bias: bool): @@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): weight = weights.get_multi_weights_col([prefix], dim=0) - if isinstance(weight, torch.Tensor): + if isinstance(weight, UnquantizedWeight): # Only on non quantized versions - weight = ( - weight.view( + weight.weight = ( + weight.weight.view( num_heads, 3, head_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 89471955..cfa891d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight class Starcoder2Config(PretrainedConfig): @@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.use_bias: w = [ From 990ea793c07b2c284afe5ac0947227f259aec398 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 14:42:35 +0200 Subject: [PATCH 139/340] Hotfix: fix MPT after recent refactor (#2257) --- .../models/causal_lm.py | 9 ++++- .../models/custom_modeling/mpt_modeling.py | 38 +++++++++---------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 0ea82b1e..87b9969a 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -492,7 +492,7 @@ class CausalLMBatch(Batch): @dataclass -class CausalLMBatchKeysLast(Batch): +class CausalLMBatchKeysLast(CausalLMBatch): keys_head_dim_last: bool = False @@ -544,7 +544,12 @@ class CausalLM(Model): config.quantize = quantize config.speculator = speculator if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = config.pad_token_id + if config.pad_token_id is not None: + tokenizer.pad_token_id = config.pad_token_id + elif config.eos_token_id is not None: + tokenizer.pad_token_id = config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id torch.distributed.barrier(group=self.process_group) weights_loader = get_loader( diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index bbe9f45a..04c547b2 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module): weights, ): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop if self.n_heads % weights.process_group.size() != 0: raise ValueError( @@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module): def __init__(self, config, prefix, weights): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) self.Wqkv = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias @@ -614,9 +614,9 @@ class MPTBlock(nn.Module): def __init__(self, config, prefix, weights): super().__init__() self.prefix = prefix - if config.attn_config["attn_type"] != "multihead_attention": + if config.attn_config.attn_type != "multihead_attention": raise NotImplementedError( - f"""Not implemented attn {config.attn_config["attn_type"]}""" + f"""Not implemented attn {config.attn_config.attn_type}""" ) resid_pdrop = config.resid_pdrop if config.no_bias: @@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel): self.world_size = weights.process_group.size() self.rank = weights.process_group.rank() self.n_heads = config.n_heads - self.attn_impl = config.attn_config["attn_impl"] - self.prefix_lm = config.attn_config["prefix_lm"] - self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] - self.alibi = config.attn_config["alibi"] - self.alibi_bias_max = config.attn_config["alibi_bias_max"] + self.attn_impl = config.attn_config.attn_impl + self.prefix_lm = config.attn_config.prefix_lm + self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id + self.alibi = config.attn_config.alibi + self.alibi_bias_max = config.attn_config.alibi_bias_max if config.init_device == "mixed": if dist.get_local_rank() == 0: config.init_device = "cpu" From e658d95c2391ff6242d4572cc4f931c0cfcea41a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 15:59:00 +0200 Subject: [PATCH 140/340] Hotfix: pass through model revision in `VlmCausalLM` (#2258) --- server/text_generation_server/models/vlm_causal_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index ace48805..f869f8b5 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -261,7 +261,12 @@ class VlmCausalLM(FlashCausalLM): **processor_kwargs, ) self.batch_class = batch_class - super().__init__(model_id=model_id, **kwargs) + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **kwargs, + ) @property def batch_type(self) -> Type[VlmCausalLMBatch]: From 66f3de583e61bfdd8d0504b5e4b8edff07a286cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Fri, 19 Jul 2024 16:17:56 +0200 Subject: [PATCH 141/340] usage stats and crash reports (#2220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * draft of usage stats * fix wrong link * launcher doesn't need sysinfo dep * only tokenizer class instead of hole struct * unused import * fix clippy errors * update openAPI doc * cargo fmt * fix error in passing flags to router * try again to update docs * run pre-commit locally * Update router/src/main.rs Co-authored-by: Hugo Larcher * Update router/src/main.rs Co-authored-by: Hugo Larcher * on crash use anonymous error event * delete json_output and ngrok * more robust way of checking if is in container * more robust nvidia smi * parse xpu more robustly * fix errors * add nvidia-smi details in docs * cargo fmt * fix clippy * should make docs check pass * Update router/src/usage_stats.rs Co-authored-by: Hugo Larcher * error reason can't be in nested json * cargo fmt --------- Co-authored-by: Hugo Larcher Co-authored-by: Erik Kaunismäki --- Cargo.lock | 53 +++- docs/source/basic_tutorials/launcher.md | 16 ++ docs/source/usage_statistics.md | 73 +++++ launcher/src/main.rs | 16 ++ router/Cargo.toml | 4 + router/src/lib.rs | 8 +- router/src/main.rs | 86 +++++- router/src/usage_stats.rs | 355 ++++++++++++++++++++++++ 8 files changed, 599 insertions(+), 12 deletions(-) create mode 100644 docs/source/usage_statistics.md create mode 100644 router/src/usage_stats.rs diff --git a/Cargo.lock b/Cargo.lock index de3977c3..e7457b73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -772,6 +772,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctrlc" version = "3.4.4" @@ -3479,9 +3500,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.118" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -3738,15 +3759,16 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.30.12" +version = "0.30.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", + "rayon", "windows", ] @@ -3892,6 +3914,7 @@ dependencies = [ "axum-tracing-opentelemetry", "base64 0.22.1", "clap", + "csv", "futures", "futures-util", "hf-hub", @@ -3913,6 +3936,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "sysinfo", "text-generation-client", "thiserror", "tokenizers", @@ -3924,6 +3948,7 @@ dependencies = [ "tracing-subscriber", "utoipa", "utoipa-swagger-ui", + "uuid", "vergen", ] @@ -4587,9 +4612,25 @@ dependencies = [ [[package]] name = "uuid" -version = "1.9.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", + "rand", + "uuid-macro-internal", +] + +[[package]] +name = "uuid-macro-internal" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] [[package]] name = "v_frame" diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f..77f88490 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -424,6 +424,22 @@ Options: [env: LORA_ADAPTERS=] +``` +## DISABLE_USAGE_STATS +```shell + --disable-usage-stats + Disable sending of all usage statistics + + [env: DISABLE_USAGE_STATS=] + +``` +## DISABLE_CRASH_REPORTS +```shell + --disable-crash-reports + Disable sending of crash reports, but allow anonymous usage statistics + + [env: DISABLE_CRASH_REPORTS=] + ``` ## HELP ```shell diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md new file mode 100644 index 00000000..adf0d70f --- /dev/null +++ b/docs/source/usage_statistics.md @@ -0,0 +1,73 @@ + +# Collection of Usage Statistics + +Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted. + +Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data then TGI runs directly on the host machine. + +## What data is collected + +The code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs). +As of release 2.1.2 this is an example of the data collected: + +- From the TGI configuration: +```json +{ + "event_type": "start", + "disable_grammar_support": false, + "max_batch_prefill_tokens": 4096, + "max_batch_size": null, + "max_batch_total_tokens": null, + "max_best_of": 2, + "max_client_batch_size": 4, + "max_concurrent_requests": 128, + "max_input_tokens": 1024, + "max_stop_sequences": 4, + "max_top_n_tokens": 5, + "max_total_tokens": 2048, + "max_waiting_tokens": 20, + "messages_api_enabled": false, + "model_config": { + "model_type": "Bloom" + }, + "revision": null, + "tokenizer_class": "BloomTokenizerFast", + "validation_workers": 2, + "waiting_served_ratio": 1.2, + "docker_label": "latest", + "git_sha": "cfc118704880453d29bcbe4fbbd91dda501cf5fe", + "nvidia_env": { + "name": "NVIDIA A10G", + "pci_bus_id": "00000000:00:1E.0", + "driver_version": "535.183.01", + "pstate": "P8", + "pcie_link_gen_max": "4", + "pcie_link_gen_current": "1", + "temperature_gpu": "31", + "utilization_gpu": "0 %", + "utilization_memory": "0 %", + "memory_total": "23028 MiB", + "memory_free": "22515 MiB", + "memory_used": "0 MiB", + "reset_status_reset_required": "No", + "reset_status_drain_and_reset_recommended": "No", + "compute_cap": "8.6", + "ecc_errors_corrected_volatile_total": "0", + "mig_mode_current": "[N/A]", + "power_draw_instant": "10.86 W", + "power_limit": "300.00 W" + }, + "system_env": { + "cpu_count": 16, + "cpu_type": "AMD EPYC 7R32", + "total_memory": 66681196544, + "architecture": "x86_64", + "platform": "linux-unix-x86_64" + } +} + +``` + +## How to opt-out + +You can easily opt out by passing the `--disable-usage-stats` to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports` which disables sending specific crash reports, but allows anonymous usage statistics. diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 04358917..dc5ebaee 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -459,6 +459,14 @@ struct Args { /// startup that will be available to callers via the `adapter_id` field in a request. #[clap(long, env)] lora_adapters: Option, + + /// Disable sending of all usage statistics + #[clap(default_value = "false", long, env)] + disable_usage_stats: bool, + + /// Disable sending of crash reports, but allow anonymous usage statistics + #[clap(default_value = "false", long, env)] + disable_crash_reports: bool, } #[derive(Debug)] @@ -1205,6 +1213,14 @@ fn spawn_webserver( args.model_id, ]; + // Pass usage stats flags to router + if args.disable_usage_stats { + router_args.push("--disable-usage-stats".to_string()); + } + if args.disable_crash_reports { + router_args.push("--disable-crash-reports".to_string()); + } + // Grammar support if args.disable_grammar_support { router_args.push("--disable-grammar-support".to_string()); diff --git a/router/Cargo.toml b/router/Cargo.toml index 60fb5c9d..0fc700a0 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -52,6 +52,10 @@ regex = "1.10.3" once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } +sysinfo = "0.30.13" +uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +csv = "1.3.0" + [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/src/lib.rs b/router/src/lib.rs index f856406d..52926c6c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -7,6 +7,8 @@ mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod usage_stats; + use serde::{Deserialize, Serialize}; use tracing::warn; use utoipa::ToSchema; @@ -40,13 +42,13 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatTemplate { name: String, template: String, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] pub enum ChatTemplateVersions { Single(String), @@ -55,7 +57,7 @@ pub enum ChatTemplateVersions { use std::path::Path; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, diff --git a/router/src/main.rs b/router/src/main.rs index b060d73c..bfc77913 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -14,6 +14,7 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_router::config::Config; +use text_generation_router::usage_stats; use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; @@ -87,6 +88,10 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, + #[clap(long, env, default_value_t)] + disable_usage_stats: bool, + #[clap(long, env, default_value_t)] + disable_crash_reports: bool, } #[derive(Debug, Subcommand)] @@ -128,6 +133,8 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + disable_usage_stats, + disable_crash_reports, command, } = args; @@ -324,6 +331,7 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); + let tokenizer_class = tokenizer_config.tokenizer_class.clone(); let tokenizer: Option = tokenizer_filename.and_then(|filename| { let mut tokenizer = Tokenizer::from_file(filename).ok(); @@ -378,8 +386,47 @@ async fn main() -> Result<(), RouterError> { } }; + // Only send usage stats when TGI is run in container and the function returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + + let user_agent = if !disable_usage_stats && is_container { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_class, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } else { + None + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + // Run server - server::run( + let result = server::run( master_shard_uds_path, model_info, compat_return_full_text, @@ -410,8 +457,41 @@ async fn main() -> Result<(), RouterError> { max_client_batch_size, print_schema_command, ) - .await?; - Ok(()) + .await; + + match result { + Ok(_) => { + if let Some(ref ua) = user_agent { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + stop_event.send().await; + }; + Ok(()) + } + Err(e) => { + if let Some(ref ua) = user_agent { + if !disable_crash_reports { + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some(e.to_string()), + ); + error_event.send().await; + } else { + let unknow_error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some("unknow_error".to_string()), + ); + unknow_error_event.send().await; + } + }; + Err(RouterError::WebServer(e)) + } + } } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs new file mode 100644 index 00000000..8559ae90 --- /dev/null +++ b/router/src/usage_stats.rs @@ -0,0 +1,355 @@ +use crate::config::Config; +use csv::ReaderBuilder; +use reqwest::header::HeaderMap; +use serde::Serialize; +use std::{ + fs::File, + io::{self, BufRead}, + path::Path, + process::Command, + time::Duration, +}; +use uuid::Uuid; + +const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; + +#[derive(Debug, Clone, Serialize)] +pub struct UserAgent { + pub uid: String, + pub args: Args, + pub env: Env, +} + +impl UserAgent { + pub fn new(reduced_args: Args) -> Self { + Self { + uid: Uuid::new_v4().to_string(), + args: reduced_args, + env: Env::new(), + } + } +} + +#[derive(Serialize, Debug)] +pub enum EventType { + Start, + Stop, + Error, +} + +#[derive(Debug, Serialize)] +pub struct UsageStatsEvent { + user_agent: UserAgent, + event_type: EventType, + #[serde(skip_serializing_if = "Option::is_none")] + error_reason: Option, +} + +impl UsageStatsEvent { + pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option) -> Self { + Self { + user_agent, + event_type, + error_reason, + } + } + pub async fn send(&self) { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + let body = serde_json::to_string(&self).unwrap(); + let client = reqwest::Client::new(); + let _ = client + .post(TELEMETRY_URL) + .headers(headers) + .body(body) + .timeout(Duration::from_secs(5)) + .send() + .await; + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct Args { + model_config: Option, + tokenizer_config: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + disable_usage_stats: bool, + disable_crash_reports: bool, +} + +impl Args { + #[allow(clippy::too_many_arguments)] + pub fn new( + model_config: Option, + tokenizer_config: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + disable_usage_stats: bool, + disable_crash_reports: bool, + ) -> Self { + Self { + model_config, + tokenizer_config, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + } + } +} + +/// This is more or less a copy of the code from the `text-generation-launcher` crate to avoid a dependency +#[derive(Serialize, Debug, Clone)] +pub struct Env { + git_sha: &'static str, + docker_label: &'static str, + nvidia_info: Option>, + xpu_info: Option>, + system_env: SystemInfo, +} + +#[derive(Debug, Serialize, Clone)] +struct NvidiaSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + pstate: String, + pcie_link_gen_max: String, + pcie_link_gen_current: String, + temperature_gpu: String, + utilization_gpu: String, + utilization_memory: String, + memory_total: String, + memory_free: String, + memory_used: String, + reset_status_reset_required: String, + reset_status_drain_and_reset_recommended: String, + compute_cap: String, + ecc_errors_corrected_volatile_total: String, + mig_mode_current: String, + power_draw_instant: String, + power_limit: String, +} + +impl NvidiaSmiInfo { + fn new() -> Option> { + let output = Command::new("nvidia-smi") + .args([ + "--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(NvidiaSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + pstate: record[3].to_string(), + pcie_link_gen_max: record[4].to_string(), + pcie_link_gen_current: record[5].to_string(), + temperature_gpu: record[6].to_string(), + utilization_gpu: record[7].to_string(), + utilization_memory: record[8].to_string(), + memory_total: record[9].to_string(), + memory_free: record[10].to_string(), + memory_used: record[11].to_string(), + reset_status_reset_required: record[12].to_string(), + reset_status_drain_and_reset_recommended: record[13].to_string(), + compute_cap: record[14].to_string(), + ecc_errors_corrected_volatile_total: record[15].to_string(), + mig_mode_current: record[16].to_string(), + power_draw_instant: record[17].to_string(), + power_limit: record[18].to_string(), + }); + } + + Some(infos) + } +} + +#[derive(Debug, Serialize, Clone)] +struct XpuSmiInfo { + device_id: usize, + gpu_utilization: f32, + gpu_power: f32, + gpu_core_temperature: f32, + gpu_memory_bandwidth_utilization: f32, +} + +impl XpuSmiInfo { + /// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format + fn new() -> Option> { + let output = Command::new("xpu-smi") + .args([ + "dump", "-d", "-1", "-m", + "0,1,3,17", // Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization + "-n", "1", "-j", + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + let mut infos = Vec::new(); + + let json_data: serde_json::Value = match serde_json::from_str(&stdout) { + Ok(data) => data, + Err(_) => return None, + }; + + if let Some(metrics_data) = json_data.as_array() { + for entry in metrics_data { + let device_id = entry["deviceId"].as_u64()? as usize; + let gpu_utilization = entry["metrics"][0].as_f64()? as f32; + let gpu_power = entry["metrics"][1].as_f64()? as f32; + let gpu_core_temperature = entry["metrics"][2].as_f64()? as f32; + let gpu_memory_bandwidth_utilization = entry["metrics"][3].as_f64()? as f32; + + infos.push(XpuSmiInfo { + device_id, + gpu_utilization, + gpu_power, + gpu_core_temperature, + gpu_memory_bandwidth_utilization, + }); + } + } + + Some(infos) + } +} + +#[derive(Serialize, Debug, Clone)] +pub struct SystemInfo { + cpu_count: usize, + cpu_type: String, + total_memory: u64, + architecture: String, + platform: String, +} + +impl SystemInfo { + fn new() -> Self { + let mut system = sysinfo::System::new_all(); + system.refresh_all(); + + let cpu_count = system.cpus().len(); + let cpu_type = system.cpus()[0].brand().to_string(); + let total_memory = system.total_memory(); + let architecture = std::env::consts::ARCH.to_string(); + let platform = format!( + "{}-{}-{}", + std::env::consts::OS, + std::env::consts::FAMILY, + std::env::consts::ARCH + ); + Self { + cpu_count, + cpu_type, + total_memory, + architecture, + platform, + } + } +} + +impl Default for Env { + fn default() -> Self { + Self::new() + } +} + +impl Env { + pub fn new() -> Self { + Self { + system_env: SystemInfo::new(), + nvidia_info: NvidiaSmiInfo::new(), + xpu_info: XpuSmiInfo::new(), + git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), + docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), + } + } +} + +pub fn is_container() -> io::Result { + let path = Path::new("/proc/self/cgroup"); + let file = File::open(path)?; + let reader = io::BufReader::new(file); + + for line in reader.lines() { + let line = line?; + // Check for common container runtimes + if line.contains("/docker/") + || line.contains("/docker-") + || line.contains("/kubepods/") + || line.contains("/kubepods-") + || line.contains("containerd") + || line.contains("crio") + || line.contains("podman") + { + return Ok(true); + } + } + Ok(false) +} From 8afc17396d338d43eef24e538c7bfc59a3483a9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Fri, 19 Jul 2024 16:34:04 +0200 Subject: [PATCH 142/340] add usage stats to toctree (#2260) quick fix --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 119c5662..e97c00aa 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -21,6 +21,8 @@ title: Messages API - local: architecture title: Internal Architecture + - local: usage_statistics + title: Usage Statistics title: Getting started - sections: - local: basic_tutorials/consuming_tgi From 898a89208246297fdbbfcaa711bc6a54939ab1c9 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 19 Jul 2024 11:12:02 -0400 Subject: [PATCH 143/340] fix: adjust default tool choice (#2244) * fix: adjust default tool choice * feat: improve tool choice syntax and response parsing/errors * fix: remove dev tests * feat: add ToolChoice to docs --- docs/openapi.json | 15 ++- router/src/infer/mod.rs | 219 ++++++++++++++++++++-------------------- router/src/lib.rs | 20 ++-- router/src/server.rs | 53 +++++----- 4 files changed, 160 insertions(+), 147 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 113fd8b5..873044a7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -909,7 +909,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -2035,6 +2035,14 @@ } } }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ { @@ -2055,6 +2063,11 @@ "$ref": "#/components/schemas/FunctionName" } } + }, + { + "type": "object", + "default": null, + "nullable": true } ] }, diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index f3b10450..db9070d4 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, }; use crate::{ FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, @@ -332,126 +332,131 @@ impl ChatTemplate { pub struct ToolGrammar {} impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + pub fn apply( tools: Option>, - tool_choice: Option, + tool_choice: ToolChoice, ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::Function { function } => { - let tool = req_tools - .iter() - .find(|tool| tool.function.name == function.name) - .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) - .clone(); - vec![tool] - } - ToolType::OneOf => req_tools.to_owned(), - }; + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); + // if tools are provided and no tool_choice we default to the OneOf + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] + } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" }), - )]) - .collect(); + ); - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 52926c6c..b6e0d09d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -826,7 +826,7 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - pub tool_choice: Option, + pub tool_choice: ToolChoice, /// Response format constraints for the generation. /// @@ -848,6 +848,7 @@ pub enum ToolType { OneOf, FunctionName(String), Function { function: FunctionName }, + NoTool, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -855,27 +856,26 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[serde(from = "ToolTypeDeserializer")] pub struct ToolChoice(pub Option); #[derive(Deserialize)] #[serde(untagged)] enum ToolTypeDeserializer { - None(Option), - Some(ToolType), + String(String), + ToolType(ToolType), } impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::None(opt) => match opt.as_deref() { - Some("none") => ToolChoice(None), - Some("auto") => ToolChoice(Some(ToolType::OneOf)), - Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), - None => ToolChoice(Some(ToolType::OneOf)), + ToolTypeDeserializer::String(s) => match s.as_str() { + "none" => ToolChoice(Some(ToolType::NoTool)), + "auto" => ToolChoice(Some(ToolType::OneOf)), + _ => ToolChoice(Some(ToolType::FunctionName(s))), }, - ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), + ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } } diff --git a/router/src/server.rs b/router/src/server.rs index d3a280ca..c56c39a3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -24,7 +24,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -1192,39 +1192,33 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - // gen_text should be valid json - let gen_text_value: Value = - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - })?; + let gen_text_value: Value = serde_json::from_str(&generation.generated_text) + .map_err(|e| InferError::ToolError(e.to_string()))?; + + let function = gen_text_value.get("function").ok_or(InferError::ToolError( + "No function found in generated text".to_string(), + ))?; + + let name = function + .get("_name") + .and_then(Value::as_str) + .ok_or(InferError::ToolError( + "No _name found in generated text".to_string(), + ))? + .to_string(); + + let mut arguments = function.clone(); + if let Value::Object(ref mut props) = arguments { + props.remove("_name"); + } + let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: gen_text_value - .get("function") - .and_then(|f| f.get("_name")) - .and_then(|name| name.as_str()) - .unwrap_or("default_function_name") - .to_string(), - // Serialize the JSON object obtained from "function" to an escaped JSON string - arguments: gen_text_value - .get("function") - .map(|f| { - let mut f_cloned = f.clone(); - if let Value::Object(ref mut props) = f_cloned { - props.remove("_name"); - } - f_cloned - }) - .unwrap_or_default(), + name, + arguments, }, }]; (Some(tool_calls), None) @@ -1498,6 +1492,7 @@ pub async fn run( ToolCall, Function, FunctionDefinition, + ToolChoice, ) ), tags( From c1638a56f13e13bf6baafbff22ee047561240f08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 17:23:20 +0200 Subject: [PATCH 144/340] Add support for Deepseek V2 (#2224) Deepseek V2 is a MoE model from Deepseek. Relevant variations compared to other models: - Grouped top-K in expert selection. - mscale in yarn is calculated using the `mscale` and `mscale_all_dim` configuration options. - `mscale_all_dim` is also used in scaling attention softmax. - Permuting of the query/key representations before applying rotary embeddings. - Some projections cannot be sharded (`q_a_proj`, `kv_a_proj_with_mqa`). So, we need weight loads that supports quantized weights. To this end `{Weights,WeightLoader}.get_weight` was added. - The query/key head dimensionality differs from that of the value, so we need to pad during attention. - Heads with size 192, needs an extension to our paged attention fork and we need to ensure that the KV cache is allocated with the correct size. - Shared experts. --- docs/source/supported_models.md | 1 + .../test_flash_deepseek_v2.json | 89 ++ .../test_flash_deepseek_v2_all_params.json | 89 ++ .../test_flash_deepseek_v2_load.json | 358 +++++++ .../models/test_flash_deepseek_v2.py | 63 ++ server/Makefile-vllm | 6 +- server/text_generation_server/layers/exl2.py | 66 +- .../layers/gptq/__init__.py | 109 ++ .../text_generation_server/layers/marlin.py | 29 + .../text_generation_server/layers/rotary.py | 27 +- .../text_generation_server/models/__init__.py | 44 +- .../flash_deepseek_v2_modeling.py | 983 ++++++++++++++++++ .../models/flash_causal_lm.py | 10 +- .../text_generation_server/utils/weights.py | 13 + 14 files changed, 1836 insertions(+), 51 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json create mode 100644 integration-tests/models/test_flash_deepseek_v2.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 2bdd00de..bc124f31 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware ## Supported Models +- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json new file mode 100644 index 00000000..03f90367 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.84375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.34375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.8359375, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.0859375, + "special": false, + "text": " is" + }, + { + "id": 254, + "logprob": -1.5390625, + "special": false, + "text": " the" + }, + { + "id": 1022, + "logprob": -1.1875, + "special": false, + "text": " first" + }, + { + "id": 3458, + "logprob": -0.35546875, + "special": false, + "text": " step" + }, + { + "id": 279, + "logprob": -0.8828125, + "special": false, + "text": " in" + }, + { + "id": 254, + "logprob": -0.71484375, + "special": false, + "text": " the" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is the first step in the" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json new file mode 100644 index 00000000..e84135cf --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2143, + "logprob": -1.828125, + "special": false, + "text": " sent" + }, + { + "id": 10081, + "logprob": -0.36914062, + "special": false, + "text": " successfully" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 185, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 1380, + "logprob": -0.38671875, + "special": false, + "text": "We" + }, + { + "id": 543, + "logprob": -0.12695312, + "special": false, + "text": " will" + }, + { + "id": 752, + "logprob": -0.20117188, + "special": false, + "text": " get" + }, + { + "id": 279, + "logprob": 0.0, + "special": false, + "text": " in" + }, + { + "id": 5402, + "logprob": 0.0, + "special": false, + "text": " touch" + }, + { + "id": 366, + "logprob": 0.0, + "special": false, + "text": " with" + } + ], + "top_tokens": null + }, + "generated_text": "Test request sent successfully.\nWe will get in touch with" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json new file mode 100644 index 00000000..a4b9784a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + } +] diff --git a/integration-tests/models/test_flash_deepseek_v2.py b/integration-tests/models/test_flash_deepseek_v2.py new file mode 100644 index 00000000..010e08c9 --- /dev/null +++ b/integration-tests/models/test_flash_deepseek_v2.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_deepseek_v2_handle(launcher): + with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_deepseek_v2(flash_deepseek_v2_handle): + await flash_deepseek_v2_handle.health(300) + return flash_deepseek_v2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_load( + flash_deepseek_v2, generate_load, response_snapshot +): + responses = await generate_load( + flash_deepseek_v2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 2f2b5ef6..f1f80529 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,14 +1,14 @@ -commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa +commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ git clone https://github.com/Narsil/vllm.git vllm; \ fi - cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build + cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build install-vllm-cuda: build-vllm-cuda - cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . + cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e . build-vllm-rocm: if [ ! -d 'vllm' ]; then \ diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index 1a5acfa8..a6e07f45 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -34,6 +34,30 @@ class Exl2Weight(Weight): class Exl2WeightsLoader(WeightsLoader): """Loader for exl2-quantized weights.""" + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + def get_weights_col_packed( self, weights: Weights, @@ -43,46 +67,12 @@ class Exl2WeightsLoader(WeightsLoader): raise RuntimeError("Column-packed weights are not supported for exl") def get_weights_col(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): raise ValueError("get_multi_weights_col is not supported for exl2") def get_weights_row(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index c98dbefc..6feca275 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -134,6 +134,115 @@ class GPTQWeightsLoader(WeightsLoader): self.quantize = quantize self.sym = sym + def get_weights(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + def get_weights_col_packed( self, weights: Weights, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index e7f017a4..a913ff57 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -33,6 +33,35 @@ class MarlinWeightsLoader(WeightsLoader): self.bits = bits self.is_marlin_24 = is_marlin_24 + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + def get_weights_col_packed( self, weights: Weights, diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 8c354b82..db78ee1c 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -1,6 +1,7 @@ import os import torch from torch import nn +from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -97,6 +98,8 @@ class PositionRotaryEmbedding(nn.Module): ) elif rope_scaling["type"] == "yarn": scaling_factor = rope_scaling["factor"] + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -109,6 +112,8 @@ class PositionRotaryEmbedding(nn.Module): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor( @@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module): scaling_factor=scaling_factor, ) elif rope_scaling["type"] == "yarn": + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) else: raise NotImplementedError( @@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim): return ramp_func -def get_mscale(scale=1): +def get_mscale(scale: float = 1.0, mscale: float = 1.0): if scale <= 1: return 1.0 - return 0.1 * math.log(scale) + 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): @@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): attn_factor, beta_fast, beta_slow, + mscale: float, + mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) @@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor + get_mscale(self.scaling_factor, mscale) + / get_mscale(self.scaling_factor, mscale_all_dim) + * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): @@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): - if seqlen > self.max_position_embeddings: + if seqlen > self.max_position_embeddings or True: inv_freq_extrapolation = _create_inv_freq( self.dim, self.base, self.inv_freq.device ) @@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.base, self.max_position_embeddings, ) + inv_freq_mask = ( 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation @@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): ) self.inv_freq = inv_freq - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 32156a06..690a8887 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -61,6 +61,10 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) @@ -141,6 +145,11 @@ if MAMBA_AVAILABLE: class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } IDEFICS2 = { "type": "idefics2", "name": "Idefics 2", @@ -459,7 +468,40 @@ def get_model( f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." ) - if model_type == MAMBA: + if model_type == DEEPSEEK_V2: + if FLASH_ATTENTION: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == MAMBA: return Mamba( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py new file mode 100644 index 00000000..8ffbd143 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -0,0 +1,983 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention.common import Seqlen +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + + +class DeepseekV2Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def _load_experts(config, prefix: str, mat: str, weights: Weights): + if config.quantize is not None: + raise NotImplementedError( + "Deepseek V2 does not support weight quantization yet." + ) + + assert mat in ["gate_proj", "up_proj", "down_proj"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.moe_intermediate_size % world_size == 0 + ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.moe_intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.n_routed_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.n_routed_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "down_proj": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class DeepseekV2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + quantize=config.quantize, + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + quantize=config.quantize, + ) + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + key, + value, + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV2MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config: DeepseekV2Config, weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + gate_proj = _load_experts( + config, f"{prefix}.experts", "gate_proj", weights + ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + + up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view( + self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim + ) + + self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1) + + self.down_proj = ( + _load_experts(config, f"{prefix}.experts", "down_proj", weights) + .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + out = ( + fused_experts( + x, + self.gate_up_proj, + self.down_proj, + topk_weights, + topk_ids, + inplace=True, + ) + * self.routed_scaling_factor + ) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DenseMoE(nn.Module): + def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + # + # Seems like no one quantizes the gate. + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.experts = [ + DeepseekV2MLP( + f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size + ) + for i in range(self.n_routed_experts) + ] + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + # gate_logits: (sequence_length, n_experts) + router_logits = self.gate(x) + + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + + out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + def moe_infer_gpu( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ): + weights = torch.zeros( + topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device + ) + weights.scatter_(1, topk_ids, topk_weight) + + out = x.new_zeros(x.shape[0], self.hidden_dim) + for i, expert in enumerate(self.experts): + # Add expert output to out with masking + out += expert(x, reduce=False) * weights[:, i].view(-1, 1) + return out + + +class DeepseekV2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE + self.mlp = moe_cls(f"{prefix}.mlp", config, weights) + else: + self.mlp = DeepseekV2MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV2Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits + + +# Functions below are from vLLM: +# +# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397 +# +# Remove after we have synced our version with upstream. + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + import triton.language as tl + from vllm import _custom_ops as ops + from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs, + invoke_fused_moe_kernel, + moe_align_block_size, + ) + + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2ca9eef3..2888f1f7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -839,7 +839,9 @@ class FlashCausalLM(Model): default_dtype=torch.float16, aliases=None, # Used for Santacoder override of config - num_kv_heads=None, + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, skip_special_tokens: bool = True, ): self.process_group, rank, world_size = initialize_torch_distributed() @@ -922,7 +924,11 @@ class FlashCausalLM(Model): else num_kv_heads ) assert self.num_kv_heads > 0 - self.head_size = config.hidden_size // config.num_attention_heads + + if head_size is None: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size self.cuda_graphs = {} self.kv_cache = [] diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6876b700..91592df0 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -21,6 +21,13 @@ class WeightsLoader(ABC): with the format, etc. """ + @abstractmethod + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + ... + @abstractmethod def get_weights_col_packed( self, @@ -104,6 +111,9 @@ class DefaultWeightsLoader(WeightsLoader): and/or concatenation. """ + def get_weights(self, weights: "Weights", prefix: str): + return weights.get_tensor(f"{prefix}.weight") + def get_weights_col_packed( self, weights: "Weights", @@ -299,6 +309,9 @@ class Weights: return tensor + def get_weights(self, prefix: str): + return self.weights_loader.get_weights(self, prefix) + def get_weights_col_packed_qkv( self, prefix: str, From 50149c380088a3374174c3f02e69f999c1d65c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Sat, 20 Jul 2024 12:26:06 +0200 Subject: [PATCH 145/340] Add FP8 release test (#2261) --- .../test_flash_llama_fp8.json | 89 +++++ .../test_flash_llama_fp8_all_params.json | 89 +++++ .../test_flash_llama_fp8_load.json | 358 ++++++++++++++++++ .../models/test_flash_llama_fp8.py | 62 +++ 4 files changed, 598 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json create mode 100644 integration-tests/models/test_flash_llama_fp8.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json new file mode 100644 index 00000000..85cfb91f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7900391, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3554688, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0039062, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.4489746, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8100586, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013015747, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json new file mode 100644 index 00000000..dcb4d063 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 25, + "logprob": -0.8535156, + "special": false, + "text": ":" + }, + { + "id": 2209, + "logprob": -2.4804688, + "special": false, + "text": " Is" + }, + { + "id": 279, + "logprob": -0.7167969, + "special": false, + "text": " the" + }, + { + "id": 734, + "logprob": -2.625, + "special": false, + "text": " function" + }, + { + "id": 330, + "logprob": -0.35131836, + "special": false, + "text": " \"" + }, + { + "id": 4110, + "logprob": -2.4101562, + "special": false, + "text": "Create" + }, + { + "id": 264, + "logprob": -0.23181152, + "special": false, + "text": " a" + }, + { + "id": 502, + "logprob": -0.25512695, + "special": false, + "text": " new" + }, + { + "id": 1052, + "logprob": -1.2792969, + "special": false, + "text": " file" + }, + { + "id": 1, + "logprob": -1.2529297, + "special": false, + "text": "\"" + } + ], + "top_tokens": null + }, + "generated_text": "Test request: Is the function \"Create a new file\"" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json new file mode 100644 index 00000000..36c87c09 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + } +] diff --git a/integration-tests/models/test_flash_llama_fp8.py b/integration-tests/models/test_flash_llama_fp8.py new file mode 100644 index 00000000..fe5df590 --- /dev/null +++ b/integration-tests/models/test_flash_llama_fp8.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_fp8_handle(launcher): + with launcher("meta-llama/Meta-Llama-3-8B", num_shard=2, quantize="fp8") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_fp8(flash_llama_fp8_handle): + await flash_llama_fp8_handle.health(300) + return flash_llama_fp8_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): + responses = await generate_load( + flash_llama_fp8, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot From 85f10ec5c91a15f3f14d6319d64fbfd5dddd45b8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Sat, 20 Jul 2024 17:02:04 +0000 Subject: [PATCH 146/340] feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248) * feat(fp8): add support for fbgemm * allow loading fp8 weights directly * update outlines * fix makefile * build fbgemm * avoid circular import and fix dockerfile * add default dtype * refactored weights loader * fix auto conversion * fix quantization config parsing * force new nccl on install * missing get_weights implementation * increase timeout --- server/Makefile | 7 +- server/Makefile-fbgemm | 15 + server/fbgemm_remove_unused.patch | 306 ++++++++++++++++++ server/fix_torch90a.sh | 11 + server/text_generation_server/cli.py | 11 +- .../layers/attention/rocm.py | 6 +- server/text_generation_server/layers/bnb.py | 16 +- server/text_generation_server/layers/eetq.py | 4 +- server/text_generation_server/layers/fp8.py | 185 ++++++++++- .../text_generation_server/layers/marlin.py | 13 +- .../text_generation_server/models/__init__.py | 23 +- .../custom_modeling/flash_llama_modeling.py | 50 ++- .../models/flash_causal_lm.py | 31 +- .../text_generation_server/models/globals.py | 12 +- server/text_generation_server/models/model.py | 11 +- .../models/vlm_causal_lm.py | 7 +- server/text_generation_server/utils/dist.py | 6 +- server/text_generation_server/utils/log.py | 13 +- .../utils/quantization.py | 41 ++- .../text_generation_server/utils/weights.py | 58 +++- 20 files changed, 721 insertions(+), 105 deletions(-) create mode 100644 server/Makefile-fbgemm create mode 100644 server/fbgemm_remove_unused.patch create mode 100755 server/fix_torch90a.sh diff --git a/server/Makefile b/server/Makefile index 33940655..209fc44e 100644 --- a/server/Makefile +++ b/server/Makefile @@ -5,6 +5,7 @@ include Makefile-awq include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica +include Makefile-fbgemm unit-tests: pytest -s -vv -m "not private" tests @@ -27,8 +28,9 @@ install-server: gen-server install: install-cuda echo "Installed server" -install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention +install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm pip install -e ".[bnb]" + pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm @@ -36,5 +38,6 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes --with cuda + poetry export -o requirements_cuda.txt --without-hashes poetry export -o requirements_rocm.txt --without-hashes + poetry export -o requirements_intel.txt --without-hashes diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm new file mode 100644 index 00000000..38f8f31f --- /dev/null +++ b/server/Makefile-fbgemm @@ -0,0 +1,15 @@ +fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca + +build-fbgemm: + chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \ + git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ + cp fbgemm_remove_unused.patch fbgemm && \ + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build + +install-fbgemm: build-fbgemm + cd fbgemm/fbgemm_gpu && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/server/fbgemm_remove_unused.patch b/server/fbgemm_remove_unused.patch new file mode 100644 index 00000000..ad6af811 --- /dev/null +++ b/server/fbgemm_remove_unused.patch @@ -0,0 +1,306 @@ +diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt +index 2244ea6f..96265a48 100644 +--- a/fbgemm_gpu/CMakeLists.txt ++++ b/fbgemm_gpu/CMakeLists.txt +@@ -94,14 +94,14 @@ endif() + # Build Experimental Modules + ################################################################################ + +-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) +- # TODO: Figure out NCCL/RCCL integration with ROCm +- add_subdirectory(experimental/example) +-endif() +- +-if(NOT FBGEMM_CPU_ONLY) +- add_subdirectory(experimental/gemm) +-endif() ++# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) ++# # TODO: Figure out NCCL/RCCL integration with ROCm ++# add_subdirectory(experimental/example) ++# endif() ++ ++# if(NOT FBGEMM_CPU_ONLY) ++# add_subdirectory(experimental/gemm) ++# endif() + + if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) + # CUTLASS currently doesn't build on ROCm and CK hasnt yet been added: +diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake +index c56773fe..0c0d349e 100644 +--- a/fbgemm_gpu/FbgemmGpu.cmake ++++ b/fbgemm_gpu/FbgemmGpu.cmake +@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources} + ################################################################################ + + set(fbgemm_gpu_sources_static_cpu +- codegen/training/forward/embedding_forward_split_cpu.cpp +- codegen/inference/embedding_forward_quantized_host_cpu.cpp +- codegen/training/backward/embedding_backward_dense_host_cpu.cpp +- codegen/utils/embedding_bounds_check_host_cpu.cpp +- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +- src/input_combine_ops/input_combine_cpu.cpp +- src/layout_transform_ops/layout_transform_ops_cpu.cpp ++ # codegen/training/forward/embedding_forward_split_cpu.cpp ++ # codegen/inference/embedding_forward_quantized_host_cpu.cpp ++ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp ++ # codegen/utils/embedding_bounds_check_host_cpu.cpp ++ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp ++ # src/input_combine_ops/input_combine_cpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_cpu.cpp + src/quantize_ops/quantize_ops_cpu.cpp + src/quantize_ops/quantize_ops_meta.cpp +- src/sparse_ops/sparse_ops_cpu.cpp +- src/sparse_ops/sparse_ops_meta.cpp +- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp +- src/split_embeddings_cache/linearize_cache_indices.cpp +- src/split_embeddings_cache/lfu_cache_populate_byte.cpp +- src/split_embeddings_cache/lru_cache_populate_byte.cpp +- src/split_embeddings_cache/lxu_cache.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++ # src/sparse_ops/sparse_ops_cpu.cpp ++ # src/sparse_ops/sparse_ops_meta.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp ++ # src/split_embeddings_cache/linearize_cache_indices.cpp ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lru_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lxu_cache.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++) + + if(NOT FBGEMM_CPU_ONLY) + list(APPEND fbgemm_gpu_sources_static_cpu +- codegen/inference/embedding_forward_quantized_host.cpp +- codegen/utils/embedding_bounds_check_host.cpp +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp +- src/layout_transform_ops/layout_transform_ops_gpu.cpp +- src/memory_utils/memory_utils.cpp +- src/memory_utils/memory_utils_ops.cpp +- src/memory_utils/memory_utils_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp ++ # codegen/inference/embedding_forward_quantized_host.cpp ++ # codegen/utils/embedding_bounds_check_host.cpp ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_gpu.cpp ++ # src/memory_utils/memory_utils.cpp ++ # src/memory_utils/memory_utils_ops.cpp ++ # src/memory_utils/memory_utils_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp + src/quantize_ops/quantize_ops_gpu.cpp +- src/sparse_ops/sparse_ops_gpu.cpp +- src/split_embeddings_utils/split_embeddings_utils.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cu +- src/metric_ops/metric_ops_host.cpp +- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp +- src/input_combine_ops/input_combine_gpu.cpp +- codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ # src/sparse_ops/sparse_ops_gpu.cpp ++ # src/split_embeddings_utils/split_embeddings_utils.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cu ++ # src/metric_ops/metric_ops_host.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp ++ # src/input_combine_ops/input_combine_gpu.cpp ++ # codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ ) + + if(NVML_LIB_PATH OR USE_ROCM) + message(STATUS "Adding merge_pooled_embeddings sources") +@@ -516,36 +518,36 @@ endif() + + if(NOT FBGEMM_CPU_ONLY) + set(fbgemm_gpu_sources_static_gpu +- codegen/utils/embedding_bounds_check.cu +- codegen/inference/embedding_forward_quantized_split_lookup.cu +- src/embedding_inplace_ops/embedding_inplace_update.cu +- src/histogram_binning_calibration_ops.cu +- src/input_combine_ops/input_combine.cu +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu +- src/memory_utils/memory_utils.cu +- src/memory_utils/memory_utils_ops.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu +- src/jagged_tensor_ops/dense_to_jagged_forward.cu +- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu +- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu +- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu +- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu +- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu +- src/jagged_tensor_ops/jagged_softmax_backward.cu +- src/jagged_tensor_ops/jagged_softmax_forward.cu +- src/jagged_tensor_ops/jagged_tensor_ops.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu +- src/jagged_tensor_ops/jagged_unique_indices.cu +- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +- src/layout_transform_ops/layout_transform_ops.cu +- src/metric_ops/metric_ops.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu +- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu ++ # codegen/utils/embedding_bounds_check.cu ++ # codegen/inference/embedding_forward_quantized_split_lookup.cu ++ # src/embedding_inplace_ops/embedding_inplace_update.cu ++ # src/histogram_binning_calibration_ops.cu ++ # src/input_combine_ops/input_combine.cu ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu ++ # src/memory_utils/memory_utils.cu ++ # src/memory_utils/memory_utils_ops.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu ++ # src/jagged_tensor_ops/dense_to_jagged_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu ++ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_softmax_backward.cu ++ # src/jagged_tensor_ops/jagged_softmax_forward.cu ++ # src/jagged_tensor_ops/jagged_tensor_ops.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu ++ # src/jagged_tensor_ops/jagged_unique_indices.cu ++ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu ++ # src/layout_transform_ops/layout_transform_ops.cu ++ # src/metric_ops/metric_ops.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu + src/quantize_ops/quantize_bfloat16.cu + src/quantize_ops/quantize_fp8_rowwise.cu + src/quantize_ops/quantize_fused_8bit_rowwise.cu +@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY) + src/quantize_ops/quantize_msfp.cu + src/quantize_ops/quantize_padded_fp8_rowwise.cu + src/quantize_ops/quantize_mx.cu +- src/sparse_ops/sparse_async_cumsum.cu +- src/sparse_ops/sparse_block_bucketize_features.cu +- src/sparse_ops/sparse_bucketize_features.cu +- src/sparse_ops/sparse_batched_unary_embeddings.cu +- src/sparse_ops/sparse_compute_frequency_sequence.cu +- src/sparse_ops/sparse_expand_into_jagged_permute.cu +- src/sparse_ops/sparse_group_index.cu +- src/sparse_ops/sparse_index_add.cu +- src/sparse_ops/sparse_index_select.cu +- src/sparse_ops/sparse_invert_permute.cu +- src/sparse_ops/sparse_pack_segments_backward.cu +- src/sparse_ops/sparse_pack_segments_forward.cu +- src/sparse_ops/sparse_permute_1d.cu +- src/sparse_ops/sparse_permute_2d.cu +- src/sparse_ops/sparse_permute102.cu +- src/sparse_ops/sparse_permute_embeddings.cu +- src/sparse_ops/sparse_range.cu +- src/sparse_ops/sparse_reorder_batched_ad.cu +- src/sparse_ops/sparse_segment_sum_csr.cu +- src/sparse_ops/sparse_zipf.cu +- src/split_embeddings_cache/lfu_cache_find.cu +- src/split_embeddings_cache/lfu_cache_populate.cu +- src/split_embeddings_cache/lfu_cache_populate_byte.cu +- src/split_embeddings_cache/lru_cache_find.cu +- src/split_embeddings_cache/lru_cache_populate.cu +- src/split_embeddings_cache/lru_cache_populate_byte.cu +- src/split_embeddings_cache/lxu_cache.cu +- src/split_embeddings_cache/linearize_cache_indices.cu +- src/split_embeddings_cache/reset_weight_momentum.cu +- src/split_embeddings_utils/generate_vbe_metadata.cu +- src/split_embeddings_utils/get_infos_metadata.cu +- src/split_embeddings_utils/radix_sort_pairs.cu +- src/split_embeddings_utils/transpose_embedding_input.cu) ++ # src/sparse_ops/sparse_async_cumsum.cu ++ # src/sparse_ops/sparse_block_bucketize_features.cu ++ # src/sparse_ops/sparse_bucketize_features.cu ++ # src/sparse_ops/sparse_batched_unary_embeddings.cu ++ # src/sparse_ops/sparse_compute_frequency_sequence.cu ++ # src/sparse_ops/sparse_expand_into_jagged_permute.cu ++ # src/sparse_ops/sparse_group_index.cu ++ # src/sparse_ops/sparse_index_add.cu ++ # src/sparse_ops/sparse_index_select.cu ++ # src/sparse_ops/sparse_invert_permute.cu ++ # src/sparse_ops/sparse_pack_segments_backward.cu ++ # src/sparse_ops/sparse_pack_segments_forward.cu ++ # src/sparse_ops/sparse_permute_1d.cu ++ # src/sparse_ops/sparse_permute_2d.cu ++ # src/sparse_ops/sparse_permute102.cu ++ # src/sparse_ops/sparse_permute_embeddings.cu ++ # src/sparse_ops/sparse_range.cu ++ # src/sparse_ops/sparse_reorder_batched_ad.cu ++ # src/sparse_ops/sparse_segment_sum_csr.cu ++ # src/sparse_ops/sparse_zipf.cu ++ # src/split_embeddings_cache/lfu_cache_find.cu ++ # src/split_embeddings_cache/lfu_cache_populate.cu ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cu ++ # src/split_embeddings_cache/lru_cache_find.cu ++ # src/split_embeddings_cache/lru_cache_populate.cu ++ # src/split_embeddings_cache/lru_cache_populate_byte.cu ++ # src/split_embeddings_cache/lxu_cache.cu ++ # src/split_embeddings_cache/linearize_cache_indices.cu ++ # src/split_embeddings_cache/reset_weight_momentum.cu ++ # src/split_embeddings_utils/generate_vbe_metadata.cu ++ # src/split_embeddings_utils/get_infos_metadata.cu ++ # src/split_embeddings_utils/radix_sort_pairs.cu ++ # src/split_embeddings_utils/transpose_embedding_input.cu) ++ ) + + set_source_files_properties(${fbgemm_gpu_sources_static_gpu} + PROPERTIES COMPILE_OPTIONS +diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +index 01f1d6ab..a6b8d7a8 100644 +--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt ++++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories + ${THIRDPARTY}/json/include + ${NCCL_INCLUDE_DIRS}) + +-set(attention_ops_sources +- src/attention/attention.cpp +- src/attention/gqa_attn_splitk.cu) ++# set(attention_ops_sources ++# src/attention/attention.cpp ++# src/attention/gqa_attn_splitk.cu) + + set(quantize_ops_sources + src/quantize/cutlass_extensions.cu + src/quantize/quantize.cu + src/quantize/quantize.cpp) + +-set(comm_ops_sources +- src/comm/car.cu +- src/comm/car.cpp) ++# set(comm_ops_sources ++# src/comm/car.cu ++# src/comm/car.cpp) + + set(experimental_gen_ai_cpp_source_files +- ${attention_ops_sources} ++ # ${attention_ops_sources} + ${quantize_ops_sources} +- ${comm_ops_sources}) ++ # ${comm_ops_sources} ++) + + set_source_files_properties(${experimental_gen_ai_cpp_source_files} + PROPERTIES INCLUDE_DIRECTORIES diff --git a/server/fix_torch90a.sh b/server/fix_torch90a.sh new file mode 100755 index 00000000..5e444828 --- /dev/null +++ b/server/fix_torch90a.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# This script is required to patch torch < 2.4 +# It adds the 90a cuda target (H100) +# This target is required to build FBGEMM kernels + +torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|') + +sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch +sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch +sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index fe839cf4..8ec2a5ae 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -8,6 +8,7 @@ from typing import Optional from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.log import log_master app = typer.Typer() @@ -87,15 +88,17 @@ def serve( ) if len(lora_adapter_ids) > 0: - logger.warning( - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + log_master( + logger.warning, + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", ) # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled # and warn the user if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: - logger.warning( - f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs." + log_master( + logger.warning, + f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.", ) global CUDA_GRAPHS CUDA_GRAPHS = None diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 99c490d5..54da63e8 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -3,6 +3,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -136,7 +137,10 @@ if ENGINE != "triton": try: import flash_attn_2_cuda - logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) except ImportError as e: if major >= 8: architecture_suffix = f"-{SYSTEM}" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index 925b0b2d..aae2bd1a 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -4,19 +4,11 @@ from functools import lru_cache import bitsandbytes as bnb import torch from bitsandbytes.nn import Int8Params, Params4bit -from loguru import logger -from text_generation_server.utils.weights import Weight - - -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class BNBWeight(Weight): +class BNBWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module): @dataclass -class BNBFP4Weight(Weight): +class BNBFP4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -90,7 +82,7 @@ class BNBFP4Weight(Weight): @dataclass -class BNBNF4Weight(Weight): +class BNBNF4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index f003f914..b1e5235a 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -2,11 +2,11 @@ from dataclasses import dataclass import torch from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class EETQWeight(Weight): +class EETQWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index b56f568a..cdf16d6b 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,8 +1,29 @@ -from dataclasses import dataclass - import torch + +from dataclasses import dataclass +from typing import Optional, Union, List +from loguru import logger + from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import log_master, log_once + +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False +try: + import fbgemm_gpu.experimental.gen_ai + + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 +except (ImportError, ModuleNotFoundError): + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") def get_fp8_linear() -> torch.nn.Module: @@ -21,12 +42,17 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device +def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): + if FBGEMM_DYN_AVAILABLE: + qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( + weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype + ) + return qweight, scale + # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -38,27 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + self.activation_scale_ub = activation_scale_ub + self.to_fp8 = to_fp8 + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + w = torch.cat(w, dim=dim) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=0) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + @dataclass class Fp8Weight(Weight): weight: torch.Tensor + dtype: torch.dtype + weight_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None def get_linear(self, bias: torch.Tensor): - return get_fp8_linear()(self.weight, bias) + if self.weight_scale is None: + return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear().from_fp8( + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + ) class Fp8Linear(torch.nn.Module): def __init__( self, - weight, + qweight, + scale, + scale_upper_bound, bias, + dtype, ) -> None: super().__init__() - self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) + self.dtype = dtype + self.qweight = qweight + self.scale = scale + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) self.bias = bias if bias is not None else None + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scale = fp8_quantize(weight) + return cls( + qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + ) + + @classmethod + def from_fp8(cls, weight, scale, input_scale, bias, dtype): + return cls( + qweight=weight, + scale=scale, + scale_upper_bound=input_scale, + bias=bias, + dtype=dtype, + ) + def forward(self, input: torch.Tensor) -> torch.Tensor: + if FBGEMM_MM_AVAILABLE: + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound + ) + + y = torch.ops.fbgemm.f8f8bf16_rowwise( + qinput, + self.qweight, + scale, + self.scale, + use_fast_accum=True, + bias=self.bias, + ) + return y.to(self.dtype) + qinput, scale = fp8_quantize(input) output, _ = torch._scaled_mm( qinput, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a913ff57..40271c35 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module): def __init__( self, - weight: torch.Tensor, + qweight: torch.Tensor, + scale: torch.Tensor, bias: Optional[torch.Tensor], ) -> None: super().__init__() @@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - qweight, scale = fp8_quantize(weight) scale = scale.to(torch.float16) qweight, scales = repack_fp8_for_marlin(qweight, scale) @@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module): out_features // 64 * 16, dtype=torch.int, device=qweight.device ) + @classmethod + def from_unquant(cls, weight, bias, _dtype): + qweight, scale = fp8_quantize(weight) + return cls(qweight=qweight, scale=scale, bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): + return cls(qweight=weight, scale=scale, bias=bias) + def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 690a8887..a43cdfed 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( ) from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_master # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -47,9 +48,7 @@ torch.set_grad_enabled(False) __all__ = [ "Model", - "BLOOMSharded", "CausalLM", - "GalacticaSharded", "Seq2SeqLM", "get_model", ] @@ -125,7 +124,7 @@ try: ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False @@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True try: from text_generation_server.models.mamba import Mamba except ImportError as e: - logger.warning(f"Could not import Mamba: {e}") + log_master(logger.warning, f"Could not import Mamba: {e}") MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: @@ -311,6 +310,12 @@ def get_model( if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 + elif quantize == "fp8": + from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + + if FBGEMM_MM_AVAILABLE: + # fbgemm kernels are fp8xfp8->bf16 + dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype. @@ -433,7 +438,9 @@ def get_model( speculate = get_speculate() if speculate > 0: - logger.info(f"Using speculation {method} with {speculate} input ids.") + log_master( + logger.info, f"Using speculation {method} with {speculate} input ids." + ) if model_type is None: # TODO: fix how we determine model type for Mamba @@ -448,10 +455,10 @@ def get_model( if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) if method in {"gptq", "awq", "exl2"}: - logger.info(f"Auto selecting quantization method {method}") + log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method else: - logger.info(f"Unknown quantization method {method}") + log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( @@ -593,7 +600,7 @@ def get_model( ) except RuntimeError as e: # Lots of legacy models with various weight names. - logger.warning(f"Couldn't load flash gpt2 variant: {e}") + log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") return CausalLM.fallback( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5237a484..f7980d2d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,7 +33,6 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -42,16 +41,15 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.fp8 import Fp8Weight from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import ( - DefaultWeightsLoader, UnquantizedWeight, Weights, ) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader if SYSTEM == "rocm": try: @@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id): @contextmanager def no_fp8(weights: Weights): + """De-activate fp8 auto conversion for the duration of this context manager""" weights_loader = weights.weights_loader - if ( - isinstance(weights_loader, DefaultWeightsLoader) - and weights_loader.weight_class is Fp8Weight - ): - weights_loader = DefaultWeightsLoader(UnquantizedWeight) + if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8: + weights_loader = HybridFP8UnquantLoader( + weights_loader.activation_scale_ub, to_fp8=False + ) with weights.use_loader(weights_loader): yield @@ -418,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.layers = nn.ModuleList( + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else "{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + self.layers.extend( [ FlashLlamaLayer( index=layer_id, @@ -430,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module): config=config, weights=weights, ) - for layer_id in range(config.num_hidden_layers) + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1) ] ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + self.norm = FastRMSNorm.load( prefix="model.norm" if not prefix else f"{prefix}.model.norm", weights=weights, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2888f1f7..cfffafa1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model +from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - hub, ) from text_generation_server.models.types import ( Batch, @@ -1156,31 +1155,36 @@ class FlashCausalLM(Model): f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) - logger.info( - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`." + log_master( + logger.info, + f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): - logger.info( - f"The file {tunableop_filepath} already exists and will be reused." + log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", ) torch.cuda.tunable.read_file(tunableop_filepath) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.tuning_enable(False) else: - logger.info( - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp." + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", ) if CUDA_GRAPHS: try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) # Warmup cuda graphs for bs in CUDA_GRAPHS: if self.speculate is None or self.speculate + 1 <= bs: @@ -1188,7 +1192,9 @@ class FlashCausalLM(Model): except torch.cuda.OutOfMemoryError: logger.exception(f"Decode cuda graph warmup failed") else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) return int(num_blocks * BLOCK_SIZE) @@ -1540,8 +1546,7 @@ class FlashCausalLM(Model): left = 0 if n_accepted_ids > 1: - if RANK == 0: - logger.debug(f"Speculated ids {n_accepted_ids - 1}") + log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") current_stopped = False for j in range(index, index + n_accepted_ids): diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06035ccd..ac42df30 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,15 +1,16 @@ import torch import os from loguru import logger -from typing import Dict +from typing import Dict, Optional + +from text_generation_server.utils.log import log_master MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: - logger.info("Using FLASH_DECODING") - + log_master(logger.info, "Using FLASH_DECODING") cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: @@ -26,11 +27,9 @@ else: if cuda_graphs is not None: cuda_graphs.sort(reverse=True) - CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. -global MODEL_ID MODEL_ID = None @@ -41,8 +40,7 @@ def set_model_id(model_id: str): # NOTE: eventually we should move this into the router and pass back the # index in all cases. -global ADAPTER_TO_INDEX -ADAPTER_TO_INDEX: Dict[str, int] = None +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 09130b85..e7748bb9 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import ( AdapterParameters, AdapterSource, ) +from text_generation_server.utils.log import log_master from loguru import logger @@ -204,8 +205,9 @@ class Model(ABC): f"order to use the dynamic adapter loading feature." ) - logger.info( - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + log_master( + logger.info, + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", ) weight_names = tuple([v[0] for v in self.target_to_layer.values()]) ( @@ -240,8 +242,9 @@ class Model(ABC): layer_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: - logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + log_master( + logger.warning, + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f869f8b5..308d5a3d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,4 +1,3 @@ -from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils.log import log_master from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info( - f"Found {num_features} features in image of resolution {height}x{width}" + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", ) return "" * num_features diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 36d63e86..82aeba6c 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -56,7 +56,7 @@ def initialize_torch_distributed(): backend = "nccl" options = ProcessGroupNCCL.Options() options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) + options._timeout = timedelta(seconds=120) else: backend = "gloo" options = None @@ -76,7 +76,7 @@ def initialize_torch_distributed(): backend="ccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: @@ -84,7 +84,7 @@ def initialize_torch_distributed(): backend=backend, world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py index b1456f1e..4385c71e 100644 --- a/server/text_generation_server/utils/log.py +++ b/server/text_generation_server/utils/log.py @@ -1,6 +1,15 @@ from functools import lru_cache +from text_generation_server.utils.dist import RANK @lru_cache(10) -def log_once(log, msg: str): - log(msg) +def log_once(log, msg: str, master=True): + if master: + log_master(log, msg) + else: + log(msg) + + +def log_master(log, msg: str): + if RANK == 0: + log(msg) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index e8e22db8..c3c038fe 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -11,6 +11,7 @@ from text_generation_server.utils.weights import ( ) +# TODO: Split this config to have a single config type per quant method @dataclass class _QuantizerConfig: bits: int @@ -21,6 +22,11 @@ class _QuantizerConfig: sym: bool +@dataclass +class _FP8QuantizerConfig: + activation_scale_ub: float + + # We should probably do this with Pytantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. def _get_quantizer_config(model_id, revision): @@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision): filename = hf_hub_download(model_id, filename=filename, revision=revision) with open(filename, "r") as f: data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + bits = data["quantization_config"]["bits"] groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models @@ -99,6 +112,12 @@ def get_loader( if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return GPTQWeightsLoader( bits=quantizer_config.bits, desc_act=quantizer_config.desc_act, @@ -127,18 +146,28 @@ def get_loader( from text_generation_server.layers.exl2 import Exl2WeightsLoader return Exl2WeightsLoader() - elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Weight - - return DefaultWeightsLoader(Fp8Weight) elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return MarlinWeightsLoader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) - elif quantize is None: - return DefaultWeightsLoader(UnquantizedWeight) + elif quantize == "fp8" or quantize is None: + from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + + # Since the default for the quantize config is _QuantizerConfig, + # we need to add this check to not get an attribute error + activation_scale_ub = None + if isinstance(quantizer_config, _FP8QuantizerConfig): + activation_scale_ub = quantizer_config.activation_scale_ub + + return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") else: raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 91592df0..66bb6051 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,12 +1,12 @@ +import torch + from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass -from enum import Enum, auto from pathlib import Path -from typing import Dict, List, Optional, Union - -import torch +from typing import Dict, List, Optional, Union, Type from safetensors import safe_open +from dataclasses import dataclass + from text_generation_server.utils.import_utils import SYSTEM @@ -84,7 +84,7 @@ class Weight(ABC): @dataclass -class UnquantizedWeight: +class UnquantizedWeight(Weight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -99,7 +99,7 @@ class UnquantizedWeight: class DefaultWeightsLoader(WeightsLoader): """Weight loader that loads (unquantized) Torch tensors.""" - def __init__(self, weight_class): + def __init__(self, weight_class: Type[UnquantizedWeight]): """Create a loader. Weights will be wrapped using the given `weights_class`, normally this will be `UnquantizedWeight`, but a quantizer-specific class such as `Fp8Weight` can be used to quantize the weights during loading. @@ -208,20 +208,29 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True): + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert # u4 which are disguised as int32. Exl2 uses int16 - # as well. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + # as well. FP8 uses torch.float8_e4m3fn + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -241,12 +250,16 @@ class Weights: raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert # u4 which are disguised as int32. exl2 uses int16. - if tensor.dtype not in (torch.int16, torch.int32): + # FP8 uses torch.float8_e4m3fn. + if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -255,10 +268,14 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) def get_packed_sharded( - self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, ) -> torch.Tensor: """ Get a shard from a tensor that packs multiple tensors. @@ -304,7 +321,16 @@ class Weights: tensor = tensor.to(device=self.device) # Avoid casting quantizer dtypes. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) return tensor From d13215da8fcf501b56d84d649808aa7cbf979be6 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Sun, 21 Jul 2024 16:48:04 +0000 Subject: [PATCH 147/340] fix(server): fix deepseekv2 loading (#2266) --- .../models/custom_modeling/flash_deepseek_v2_modeling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 8ffbd143..f5b2ba0e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -34,7 +34,6 @@ from text_generation_server.layers.attention.common import Seqlen from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weights from torch import nn from transformers.activations import ACT2FN @@ -240,7 +239,6 @@ class DeepseekV2Attention(torch.nn.Module): if config.attention_bias else None ), - quantize=config.quantize, ) self.q_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.q_a_layernorm", @@ -261,7 +259,6 @@ class DeepseekV2Attention(torch.nn.Module): if config.attention_bias else None ), - quantize=config.quantize, ) self.kv_a_layernorm = FastRMSNorm.load( From a5aee82a6974f3318541831cc02415057cc37557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?icyboy=E2=84=A2?= Date: Mon, 22 Jul 2024 17:31:00 +0800 Subject: [PATCH 148/340] Hotfix: fix of use of unquantized weights in Mixtral GQA loading (#2269) * Update idefics_causal_lm.py Fix syntax issues * fix dbrx & opt model prefix bug * Hotfix: fix of use of unquantized weights in Mixtral GQA loading --- .../models/custom_modeling/flash_mixtral_modeling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index a1e36fc7..aae6737e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -52,6 +52,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight class MixtralConfig(PretrainedConfig): @@ -138,16 +139,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" return TensorParallelColumnLinear(get_linear(weight, bias=None)) From 758a8b84234872bdb7a51f6c4daa4ed6681f1c00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Mon, 22 Jul 2024 12:00:17 +0200 Subject: [PATCH 149/340] legacy warning on text_generation client (#2271) Update README.md point to huggingface_hub inference clients instead --- clients/python/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/clients/python/README.md b/clients/python/README.md index bf37508e..88239aa1 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -1,3 +1,6 @@ +# Legacy warning ⚠️ +The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`. + # Text Generation The Hugging Face Text Generation Python library provides a convenient way of interfacing with a From a7515b8af11a91577a93957a4df78186f7d546e7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 22 Jul 2024 15:51:32 +0000 Subject: [PATCH 150/340] fix(server): fix fp8 weight loading (#2268) * fix(server): fix fp8 weight loading * fixed scales loading * update snap * revert default dtype --- .../test_flash_llama_fp8_all_params.json | 50 +++++++++---------- server/text_generation_server/layers/fp8.py | 20 +++++--- .../text_generation_server/layers/marlin.py | 26 +++++----- .../text_generation_server/models/__init__.py | 35 +++++++------ .../text_generation_server/utils/weights.py | 13 +++-- 5 files changed, 82 insertions(+), 62 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json index dcb4d063..bf981e4f 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -11,12 +11,12 @@ }, { "id": 2323, - "logprob": -9.421875, + "logprob": -9.5625, "text": "Test" }, { "id": 1715, - "logprob": -10.546875, + "logprob": -10.375, "text": " request" } ], @@ -24,66 +24,66 @@ "tokens": [ { "id": 25, - "logprob": -0.8535156, + "logprob": -0.8984375, "special": false, "text": ":" }, { "id": 2209, - "logprob": -2.4804688, + "logprob": -2.78125, "special": false, "text": " Is" }, { "id": 279, - "logprob": -0.7167969, + "logprob": -0.6328125, "special": false, "text": " the" }, { "id": 734, - "logprob": -2.625, + "logprob": -2.703125, "special": false, "text": " function" }, { "id": 330, - "logprob": -0.35131836, + "logprob": -0.34179688, "special": false, "text": " \"" }, { "id": 4110, - "logprob": -2.4101562, + "logprob": -2.359375, "special": false, "text": "Create" }, { - "id": 264, - "logprob": -0.23181152, + "id": 7575, + "logprob": -2.1875, "special": false, - "text": " a" - }, - { - "id": 502, - "logprob": -0.25512695, - "special": false, - "text": " new" - }, - { - "id": 1052, - "logprob": -1.2792969, - "special": false, - "text": " file" + "text": "Process" }, { "id": 1, - "logprob": -1.2529297, + "logprob": -0.07910156, "special": false, "text": "\"" + }, + { + "id": 304, + "logprob": -0.83203125, + "special": false, + "text": " in" + }, + { + "id": 12468, + "logprob": -1.8203125, + "special": false, + "text": " Win" } ], "top_tokens": null }, - "generated_text": "Test request: Is the function \"Create a new file\"" + "generated_text": "Test request: Is the function \"CreateProcess\" in Win" } diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index cdf16d6b..bf5a0989 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: # FP8 branch - scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) return Fp8Weight( weight=w, weight_scale=scale, @@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader): # FP8 branch scale = weights.get_packed_sharded( f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False - ) + ).reshape(-1) return Fp8Weight( weight=w, weight_scale=scale, @@ -115,8 +117,12 @@ class HybridFP8UnquantLoader(WeightsLoader): return UnquantizedWeight(w) def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): - w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - w = torch.cat(w, dim=dim) + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) # FP8 branch if w.dtype == torch.float8_e4m3fn: @@ -124,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader): weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) for p in prefixes ] - scale = torch.cat(scale, dim=0) + scale = torch.cat(scale, dim=0).reshape(-1) return Fp8Weight( weight=w, weight_scale=scale, @@ -140,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: - scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False) + scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) return Fp8Weight( weight=w, weight_scale=scale, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 40271c35..a28012da 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -504,7 +504,7 @@ class GPTQMarlinFP8Linear(nn.Module): def __init__( self, qweight: torch.Tensor, - scale: torch.Tensor, + scales: torch.Tensor, bias: Optional[torch.Tensor], ) -> None: super().__init__() @@ -514,8 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - scale = scale.to(torch.float16) - qweight, scales = repack_fp8_for_marlin(qweight, scale) + scales = scales.unsqueeze(0) + if scales.shape[1] == 1: + out_features, in_features = qweight.shape + scales = scales.repeat(1, out_features) + qweight, scales = repack_fp8_for_marlin(qweight, scales) in_features = qweight.shape[0] * MARLIN_TILE_SIZE out_features = scales.shape[1] @@ -530,13 +533,13 @@ class GPTQMarlinFP8Linear(nn.Module): ) @classmethod - def from_unquant(cls, weight, bias, _dtype): - qweight, scale = fp8_quantize(weight) - return cls(qweight=qweight, scale=scale, bias=bias) + def from_unquant(cls, weight, bias, dtype): + qweight, scales = fp8_quantize(weight) + return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): - return cls(qweight=weight, scale=scale, bias=bias) + def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + return cls(qweight=weight, scales=scale.to(dtype), bias=bias) def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None @@ -591,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: return packed -def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): +def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): """ Repack FP8 tensor for GPTQ-Marlin. """ @@ -608,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): qweight, perm, in_features, out_features, 8 ) - scales = scale.reshape(1, 1).repeat(1, out_features) scales = permute_scales(scales) return repacked, scales @@ -621,7 +623,7 @@ class MarlinWeight(Weight): Attributes: B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): float16 scales. + s (torch.Tensor): bfloat16/float16 scales. """ B: torch.Tensor @@ -629,7 +631,7 @@ class MarlinWeight(Weight): def __post_init__(self): assert self.B.dtype == torch.int32 - assert self.s.dtype == torch.float16 + assert self.s.dtype in [torch.float16, torch.bfloat16] def get_linear(self, bias: torch.Tensor): return MarlinLinear(weight=self, bias=bias) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a43cdfed..1cd13a2a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -306,14 +306,32 @@ def get_model( max_input_tokens: int, ) -> Model: global FLASH_ATTENTION + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + model_type = config_dict.get("model_type", None) + + quantization_config = config_dict.get("quantization_config", None) + if quantization_config is not None and quantize is None: + method = quantization_config.get("quant_method", None) + if method in {"gptq", "awq", "exl2"}: + log_master(logger.info, f"Auto selecting quantization method {method}") + quantize = method + elif method == "fbgemm_fp8": + log_master(logger.info, "Auto selecting quantization method fp8") + quantize = "fp8" + else: + log_master(logger.warning, f"Unknown quantization method {method}") + if dtype is None: if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 elif quantize == "fp8": - from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE - if FBGEMM_MM_AVAILABLE: + if FBGEMM_DYN_AVAILABLE: # fbgemm kernels are fp8xfp8->bf16 dtype = torch.bfloat16 else: @@ -332,11 +350,6 @@ def get_model( else: set_speculate(0) - config_dict, _ = PretrainedConfig.get_config_dict( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - model_type = config_dict.get("model_type", None) - speculator = None if "medusa_num_heads" in config_dict: medusa_model_id = model_id @@ -451,14 +464,6 @@ def get_model( raise RuntimeError( f"Could not determine model type for {model_id} revision {revision}" ) - quantization_config = config_dict.get("quantization_config", None) - if quantization_config is not None and quantize is None: - method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq", "exl2"}: - log_master(logger.info, f"Auto selecting quantization method {method}") - quantize = method - else: - log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 66bb6051..108ced48 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -230,7 +230,9 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): + def get_partial_sharded( + self, tensor_name: str, dim: int, to_device=True, to_dtype=True + ): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -256,10 +258,11 @@ class Weights: and to_dtype ): tensor = tensor.to(dtype=self.dtype) - tensor = tensor.to(device=self.device) + if to_device: + tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): + def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -268,7 +271,9 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) + return self.get_partial_sharded( + tensor_name, dim, to_device=to_device, to_dtype=to_dtype + ) def get_packed_sharded( self, From 568cc9f3d07945a20d77a4f63a06f1f603b7ccb7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 22 Jul 2024 18:27:10 +0200 Subject: [PATCH 151/340] Softcapping for gemma2. (#2273) * Softcapping for gemma2. * Less clutter. * No access to transformers config, only config_dict here. * 0.0 is the null value in the C++ API. --- .../test_flash_gemma2/test_flash_gemma2.json | 254 ++++ .../test_flash_gemma2_load.json | 1018 +++++++++++++++ integration-tests/models/test_flash_gemma2.py | 46 + launcher/src/main.rs | 6 + server/Makefile-flash-att-v2 | 2 +- server/poetry.lock | 1135 +++++++++-------- server/pyproject.toml | 2 +- server/requirements_cuda.txt | 46 +- server/requirements_intel.txt | 46 +- server/requirements_rocm.txt | 46 +- .../layers/attention/cuda.py | 14 + .../text_generation_server/models/__init__.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 10 + 13 files changed, 2012 insertions(+), 615 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json create mode 100644 integration-tests/models/test_flash_gemma2.py diff --git a/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json new file mode 100644 index 00000000..1e9a50cf --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json @@ -0,0 +1,254 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.15625, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.1171875, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.84375, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.546875, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.65625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.84375, + "text": " " + }, + { + "id": 235274, + "logprob": -0.45117188, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07421875, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.4140625, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009536743, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.033203125, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.0002670288, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.75, + "text": "," + }, + { + "id": 7385, + "logprob": -11.625, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.29882812, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16699219, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.5, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.125, + "text": ":" + }, + { + "id": 108, + "logprob": -3.421875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.099121094, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.078125, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.025756836, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.29101562, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035858154, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.1007996e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json new file mode 100644 index 00000000..5c47dd3c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json @@ -0,0 +1,1018 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.25, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.111816406, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.78125, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.59375, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.7421875, + "text": " " + }, + { + "id": 235274, + "logprob": -0.44921875, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07128906, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.40429688, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009918213, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.03540039, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.00028800964, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.71875, + "text": "," + }, + { + "id": 7385, + "logprob": -11.875, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.3046875, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16113281, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.25, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.1875, + "text": ":" + }, + { + "id": 108, + "logprob": -3.296875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08886719, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.09375, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035552979, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.220009e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.25, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.111816406, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.78125, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.59375, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.7421875, + "text": " " + }, + { + "id": 235274, + "logprob": -0.44921875, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07128906, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.40429688, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009918213, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.03540039, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.00028800964, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.71875, + "text": "," + }, + { + "id": 7385, + "logprob": -11.875, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.3046875, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16113281, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.25, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.1875, + "text": ":" + }, + { + "id": 108, + "logprob": -3.296875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08886719, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.09375, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035552979, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.220009e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.15625, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.1171875, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.84375, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.546875, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.65625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.84375, + "text": " " + }, + { + "id": 235274, + "logprob": -0.45117188, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07421875, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.4140625, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009536743, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.033203125, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.0002670288, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.75, + "text": "," + }, + { + "id": 7385, + "logprob": -11.625, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.29882812, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16699219, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.5, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.125, + "text": ":" + }, + { + "id": 108, + "logprob": -3.421875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08984375, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.1015625, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0038452148, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.1484833e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.25, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.111816406, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.78125, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.59375, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.7421875, + "text": " " + }, + { + "id": 235274, + "logprob": -0.44921875, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07128906, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.40429688, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009918213, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.03540039, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.00028800964, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.71875, + "text": "," + }, + { + "id": 7385, + "logprob": -11.875, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.3046875, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16113281, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.25, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.1875, + "text": ":" + }, + { + "id": 108, + "logprob": -3.296875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08886719, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.09375, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035552979, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.220009e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + } +] diff --git a/integration-tests/models/test_flash_gemma2.py b/integration-tests/models/test_flash_gemma2.py new file mode 100644 index 00000000..547db493 --- /dev/null +++ b/integration-tests/models/test_flash_gemma2.py @@ -0,0 +1,46 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma2_handle(launcher): + with launcher("google/gemma-2-9b-it", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma2(flash_gemma2_handle): + await flash_gemma2_handle.health(300) + return flash_gemma2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma2(flash_gemma2, response_snapshot): + response = await flash_gemma2.generate( + "user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.\nmodel:\n", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.generated_text == "**Hydrogen**, light and free,\n**He" + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot): + responses = await generate_load( + flash_gemma2, + "user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.\nmodel:\n", + max_new_tokens=10, + n=4, + ) + + assert responses[0].generated_text == "**Hydrogen**, light and free,\n**He" + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index dc5ebaee..82dcf379 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -27,6 +27,7 @@ mod env_runtime; struct RawConfig { max_position_embeddings: Option, n_positions: Option, + model_type: Option, max_seq_len: Option, } @@ -1432,6 +1433,11 @@ fn main() -> Result<(), LauncherError> { let content = std::fs::read_to_string(filename)?; let config: RawConfig = serde_json::from_str(&content)?; + + if config.model_type == Some("gemma2".to_string()) { + tracing::info!("Forcing flash decoding because of softcap usage"); + std::env::set_var("FLASH_DECODING", "1"); + } let config: Config = config.into(); // Quantization usually means you're even more RAM constrained. diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index ba90a74d..dbddd0f4 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.5.9.post1 +flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 build-flash-attention-v2-cuda: diff --git a/server/poetry.lock b/server/poetry.lock index 4984978a..3dbd24e2 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -202,13 +202,13 @@ test = ["scipy"] [[package]] name = "certifi" -version = "2024.6.2" +version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, - {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, + {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, + {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, ] [[package]] @@ -348,45 +348,47 @@ files = [ [[package]] name = "datasets" -version = "2.14.4" +version = "2.20.0" description = "HuggingFace community-driven open-source library of datasets" optional = true python-versions = ">=3.8.0" files = [ - {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, - {file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, + {file = "datasets-2.20.0-py3-none-any.whl", hash = "sha256:76ac02e3bdfff824492e20678f0b6b1b6d080515957fe834b00c2ba8d6b18e5e"}, + {file = "datasets-2.20.0.tar.gz", hash = "sha256:3c4dbcd27e0f642b9d41d20ff2efa721a5e04b32b2ca4009e0fc9139e324553f"}, ] [package.dependencies] aiohttp = "*" -dill = ">=0.3.0,<0.3.8" -fsspec = {version = ">=2021.11.1", extras = ["http"]} -huggingface-hub = ">=0.14.0,<1.0.0" +dill = ">=0.3.0,<0.3.9" +filelock = "*" +fsspec = {version = ">=2023.1.0,<=2024.5.0", extras = ["http"]} +huggingface-hub = ">=0.21.2" multiprocess = "*" numpy = ">=1.17" packaging = "*" pandas = "*" -pyarrow = ">=8.0.0" +pyarrow = ">=15.0.0" +pyarrow-hotfix = "*" pyyaml = ">=5.1" -requests = ">=2.19.0" -tqdm = ">=4.62.1" +requests = ">=2.32.2" +tqdm = ">=4.66.3" xxhash = "*" [package.extras] -apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] -docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] -jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +dev = ["Pillow (>=9.4.0)", "absl-py", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] -quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] +quality = ["ruff (>=0.3.0)"] s3 = ["s3fs"] -tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] -tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +tensorflow = ["tensorflow (>=2.6.0)"] +tensorflow-gpu = ["tensorflow (>=2.6.0)"] +tests = ["Pillow (>=9.4.0)", "absl-py", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] torch = ["torch"] -vision = ["Pillow (>=6.2.1)"] +vision = ["Pillow (>=9.4.0)"] [[package]] name = "deprecated" @@ -407,17 +409,18 @@ dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] [[package]] name = "dill" -version = "0.3.7" +version = "0.3.8" description = "serialize all of Python" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, - {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, ] [package.extras] graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] [[package]] name = "diskcache" @@ -443,13 +446,13 @@ files = [ [[package]] name = "exceptiongroup" -version = "1.2.1" +version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, - {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, ] [package.extras] @@ -457,18 +460,18 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.4" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -559,13 +562,13 @@ files = [ [[package]] name = "fsspec" -version = "2024.6.0" +version = "2024.5.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.0-py3-none-any.whl", hash = "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee"}, - {file = "fsspec-2024.6.0.tar.gz", hash = "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2"}, + {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"}, + {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"}, ] [package.dependencies] @@ -577,7 +580,6 @@ adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] dev = ["pre-commit", "ruff"] -doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] @@ -601,17 +603,17 @@ tqdm = ["tqdm"] [[package]] name = "googleapis-common-protos" -version = "1.63.1" +version = "1.63.2" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" files = [ - {file = "googleapis-common-protos-1.63.1.tar.gz", hash = "sha256:c6442f7a0a6b2a80369457d79e6672bb7dcbaab88e0848302497e3ec80780a6a"}, - {file = "googleapis_common_protos-1.63.1-py2.py3-none-any.whl", hash = "sha256:0e1c2cdfcbc354b76e4a211a35ea35d6926a835cba1377073c4861db904a1877"}, + {file = "googleapis-common-protos-1.63.2.tar.gz", hash = "sha256:27c5abdffc4911f28101e635de1533fb4cfd2c37fbaa9174587c799fac90aa87"}, + {file = "googleapis_common_protos-1.63.2-py2.py3-none-any.whl", hash = "sha256:27a2499c7e8aff199665b22741997e485eccc8645aa9176c7c988e6fae507945"}, ] [package.dependencies] -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] @@ -635,61 +637,61 @@ testing = ["protobuf (>=4.21.9)"] [[package]] name = "grpcio" -version = "1.64.1" +version = "1.65.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.64.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:55697ecec192bc3f2f3cc13a295ab670f51de29884ca9ae6cd6247df55df2502"}, - {file = "grpcio-1.64.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3b64ae304c175671efdaa7ec9ae2cc36996b681eb63ca39c464958396697daff"}, - {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:bac71b4b28bc9af61efcdc7630b166440bbfbaa80940c9a697271b5e1dabbc61"}, - {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c024ffc22d6dc59000faf8ad781696d81e8e38f4078cb0f2630b4a3cf231a90"}, - {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7cd5c1325f6808b8ae31657d281aadb2a51ac11ab081ae335f4f7fc44c1721d"}, - {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0a2813093ddb27418a4c99f9b1c223fab0b053157176a64cc9db0f4557b69bd9"}, - {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2981c7365a9353f9b5c864595c510c983251b1ab403e05b1ccc70a3d9541a73b"}, - {file = "grpcio-1.64.1-cp310-cp310-win32.whl", hash = "sha256:1262402af5a511c245c3ae918167eca57342c72320dffae5d9b51840c4b2f86d"}, - {file = "grpcio-1.64.1-cp310-cp310-win_amd64.whl", hash = "sha256:19264fc964576ddb065368cae953f8d0514ecc6cb3da8903766d9fb9d4554c33"}, - {file = "grpcio-1.64.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:58b1041e7c870bb30ee41d3090cbd6f0851f30ae4eb68228955d973d3efa2e61"}, - {file = "grpcio-1.64.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bbc5b1d78a7822b0a84c6f8917faa986c1a744e65d762ef6d8be9d75677af2ca"}, - {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5841dd1f284bd1b3d8a6eca3a7f062b06f1eec09b184397e1d1d43447e89a7ae"}, - {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8caee47e970b92b3dd948371230fcceb80d3f2277b3bf7fbd7c0564e7d39068e"}, - {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73819689c169417a4f978e562d24f2def2be75739c4bed1992435d007819da1b"}, - {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6503b64c8b2dfad299749cad1b595c650c91e5b2c8a1b775380fcf8d2cbba1e9"}, - {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1de403fc1305fd96cfa75e83be3dee8538f2413a6b1685b8452301c7ba33c294"}, - {file = "grpcio-1.64.1-cp311-cp311-win32.whl", hash = "sha256:d4d29cc612e1332237877dfa7fe687157973aab1d63bd0f84cf06692f04c0367"}, - {file = "grpcio-1.64.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e56462b05a6f860b72f0fa50dca06d5b26543a4e88d0396259a07dc30f4e5aa"}, - {file = "grpcio-1.64.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59"}, - {file = "grpcio-1.64.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1"}, - {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22"}, - {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762"}, - {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1"}, - {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb"}, - {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb"}, - {file = "grpcio-1.64.1-cp312-cp312-win32.whl", hash = "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027"}, - {file = "grpcio-1.64.1-cp312-cp312-win_amd64.whl", hash = "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6"}, - {file = "grpcio-1.64.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:a011ac6c03cfe162ff2b727bcb530567826cec85eb8d4ad2bfb4bd023287a52d"}, - {file = "grpcio-1.64.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4d6dab6124225496010bd22690f2d9bd35c7cbb267b3f14e7a3eb05c911325d4"}, - {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:a5e771d0252e871ce194d0fdcafd13971f1aae0ddacc5f25615030d5df55c3a2"}, - {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c3c1b90ab93fed424e454e93c0ed0b9d552bdf1b0929712b094f5ecfe7a23ad"}, - {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20405cb8b13fd779135df23fabadc53b86522d0f1cba8cca0e87968587f50650"}, - {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0cc79c982ccb2feec8aad0e8fb0d168bcbca85bc77b080d0d3c5f2f15c24ea8f"}, - {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a3a035c37ce7565b8f4f35ff683a4db34d24e53dc487e47438e434eb3f701b2a"}, - {file = "grpcio-1.64.1-cp38-cp38-win32.whl", hash = "sha256:1257b76748612aca0f89beec7fa0615727fd6f2a1ad580a9638816a4b2eb18fd"}, - {file = "grpcio-1.64.1-cp38-cp38-win_amd64.whl", hash = "sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122"}, - {file = "grpcio-1.64.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:75dbbf415026d2862192fe1b28d71f209e2fd87079d98470db90bebe57b33179"}, - {file = "grpcio-1.64.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e3d9f8d1221baa0ced7ec7322a981e28deb23749c76eeeb3d33e18b72935ab62"}, - {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:5f8b75f64d5d324c565b263c67dbe4f0af595635bbdd93bb1a88189fc62ed2e5"}, - {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c84ad903d0d94311a2b7eea608da163dace97c5fe9412ea311e72c3684925602"}, - {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:940e3ec884520155f68a3b712d045e077d61c520a195d1a5932c531f11883489"}, - {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309"}, - {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac15b6c2c80a4d1338b04d42a02d376a53395ddf0ec9ab157cbaf44191f3ffdd"}, - {file = "grpcio-1.64.1-cp39-cp39-win32.whl", hash = "sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040"}, - {file = "grpcio-1.64.1-cp39-cp39-win_amd64.whl", hash = "sha256:ed6091fa0adcc7e4ff944090cf203a52da35c37a130efa564ded02b7aff63bcd"}, - {file = "grpcio-1.64.1.tar.gz", hash = "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a"}, + {file = "grpcio-1.65.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:3dc5f928815b8972fb83b78d8db5039559f39e004ec93ebac316403fe031a062"}, + {file = "grpcio-1.65.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:8333ca46053c35484c9f2f7e8d8ec98c1383a8675a449163cea31a2076d93de8"}, + {file = "grpcio-1.65.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:7af64838b6e615fff0ec711960ed9b6ee83086edfa8c32670eafb736f169d719"}, + {file = "grpcio-1.65.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbb64b4166362d9326f7efbf75b1c72106c1aa87f13a8c8b56a1224fac152f5c"}, + {file = "grpcio-1.65.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8422dc13ad93ec8caa2612b5032a2b9cd6421c13ed87f54db4a3a2c93afaf77"}, + {file = "grpcio-1.65.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:4effc0562b6c65d4add6a873ca132e46ba5e5a46f07c93502c37a9ae7f043857"}, + {file = "grpcio-1.65.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a6c71575a2fedf259724981fd73a18906513d2f306169c46262a5bae956e6364"}, + {file = "grpcio-1.65.1-cp310-cp310-win32.whl", hash = "sha256:34966cf526ef0ea616e008d40d989463e3db157abb213b2f20c6ce0ae7928875"}, + {file = "grpcio-1.65.1-cp310-cp310-win_amd64.whl", hash = "sha256:ca931de5dd6d9eb94ff19a2c9434b23923bce6f767179fef04dfa991f282eaad"}, + {file = "grpcio-1.65.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:bbb46330cc643ecf10bd9bd4ca8e7419a14b6b9dedd05f671c90fb2c813c6037"}, + {file = "grpcio-1.65.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d827a6fb9215b961eb73459ad7977edb9e748b23e3407d21c845d1d8ef6597e5"}, + {file = "grpcio-1.65.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:6e71aed8835f8d9fbcb84babc93a9da95955d1685021cceb7089f4f1e717d719"}, + {file = "grpcio-1.65.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a1c84560b3b2d34695c9ba53ab0264e2802721c530678a8f0a227951f453462"}, + {file = "grpcio-1.65.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27adee2338d697e71143ed147fe286c05810965d5d30ec14dd09c22479bfe48a"}, + {file = "grpcio-1.65.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f62652ddcadc75d0e7aa629e96bb61658f85a993e748333715b4ab667192e4e8"}, + {file = "grpcio-1.65.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:71a05fd814700dd9cb7d9a507f2f6a1ef85866733ccaf557eedacec32d65e4c2"}, + {file = "grpcio-1.65.1-cp311-cp311-win32.whl", hash = "sha256:b590f1ad056294dfaeac0b7e1b71d3d5ace638d8dd1f1147ce4bd13458783ba8"}, + {file = "grpcio-1.65.1-cp311-cp311-win_amd64.whl", hash = "sha256:12e9bdf3b5fd48e5fbe5b3da382ad8f97c08b47969f3cca81dd9b36b86ed39e2"}, + {file = "grpcio-1.65.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:54cb822e177374b318b233e54b6856c692c24cdbd5a3ba5335f18a47396bac8f"}, + {file = "grpcio-1.65.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:aaf3c54419a28d45bd1681372029f40e5bfb58e5265e3882eaf21e4a5f81a119"}, + {file = "grpcio-1.65.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:557de35bdfbe8bafea0a003dbd0f4da6d89223ac6c4c7549d78e20f92ead95d9"}, + {file = "grpcio-1.65.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8bfd95ef3b097f0cc86ade54eafefa1c8ed623aa01a26fbbdcd1a3650494dd11"}, + {file = "grpcio-1.65.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e6a8f3d6c41e6b642870afe6cafbaf7b61c57317f9ec66d0efdaf19db992b90"}, + {file = "grpcio-1.65.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1faaf7355ceed07ceaef0b9dcefa4c98daf1dd8840ed75c2de128c3f4a4d859d"}, + {file = "grpcio-1.65.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:60f1f38eed830488ad2a1b11579ef0f345ff16fffdad1d24d9fbc97ba31804ff"}, + {file = "grpcio-1.65.1-cp312-cp312-win32.whl", hash = "sha256:e75acfa52daf5ea0712e8aa82f0003bba964de7ae22c26d208cbd7bc08500177"}, + {file = "grpcio-1.65.1-cp312-cp312-win_amd64.whl", hash = "sha256:ff5a84907e51924973aa05ed8759210d8cdae7ffcf9e44fd17646cf4a902df59"}, + {file = "grpcio-1.65.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:1fbd6331f18c3acd7e09d17fd840c096f56eaf0ef830fbd50af45ae9dc8dfd83"}, + {file = "grpcio-1.65.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:de5b6be29116e094c5ef9d9e4252e7eb143e3d5f6bd6d50a78075553ab4930b0"}, + {file = "grpcio-1.65.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:e4a3cdba62b2d6aeae6027ae65f350de6dc082b72e6215eccf82628e79efe9ba"}, + {file = "grpcio-1.65.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:941c4869aa229d88706b78187d60d66aca77fe5c32518b79e3c3e03fc26109a2"}, + {file = "grpcio-1.65.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f40cebe5edb518d78b8131e87cb83b3ee688984de38a232024b9b44e74ee53d3"}, + {file = "grpcio-1.65.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2ca684ba331fb249d8a1ce88db5394e70dbcd96e58d8c4b7e0d7b141a453dce9"}, + {file = "grpcio-1.65.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8558f0083ddaf5de64a59c790bffd7568e353914c0c551eae2955f54ee4b857f"}, + {file = "grpcio-1.65.1-cp38-cp38-win32.whl", hash = "sha256:8d8143a3e3966f85dce6c5cc45387ec36552174ba5712c5dc6fcc0898fb324c0"}, + {file = "grpcio-1.65.1-cp38-cp38-win_amd64.whl", hash = "sha256:76e81a86424d6ca1ce7c16b15bdd6a964a42b40544bf796a48da241fdaf61153"}, + {file = "grpcio-1.65.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb5175f45c980ff418998723ea1b3869cce3766d2ab4e4916fbd3cedbc9d0ed3"}, + {file = "grpcio-1.65.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b12c1aa7b95abe73b3e04e052c8b362655b41c7798da69f1eaf8d186c7d204df"}, + {file = "grpcio-1.65.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:3019fb50128b21a5e018d89569ffaaaa361680e1346c2f261bb84a91082eb3d3"}, + {file = "grpcio-1.65.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ae15275ed98ea267f64ee9ddedf8ecd5306a5b5bb87972a48bfe24af24153e8"}, + {file = "grpcio-1.65.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f096ffb881f37e8d4f958b63c74bfc400c7cebd7a944b027357cd2fb8d91a57"}, + {file = "grpcio-1.65.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2f56b5a68fdcf17a0a1d524bf177218c3c69b3947cb239ea222c6f1867c3ab68"}, + {file = "grpcio-1.65.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:941596d419b9736ab548aa0feb5bbba922f98872668847bf0720b42d1d227b9e"}, + {file = "grpcio-1.65.1-cp39-cp39-win32.whl", hash = "sha256:5fd7337a823b890215f07d429f4f193d24b80d62a5485cf88ee06648591a0c57"}, + {file = "grpcio-1.65.1-cp39-cp39-win_amd64.whl", hash = "sha256:1bceeec568372cbebf554eae1b436b06c2ff24cfaf04afade729fb9035408c6c"}, + {file = "grpcio-1.65.1.tar.gz", hash = "sha256:3c492301988cd720cd145d84e17318d45af342e29ef93141228f9cd73222368b"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.64.1)"] +protobuf = ["grpcio-tools (>=1.65.1)"] [[package]] name = "grpcio-reflection" @@ -864,13 +866,13 @@ files = [ [[package]] name = "huggingface-hub" -version = "0.23.2" +version = "0.23.5" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, - {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, + {file = "huggingface_hub-0.23.5-py3-none-any.whl", hash = "sha256:d7a7d337615e11a45cc14a0ce5a605db6b038dc24af42866f731684825226e90"}, + {file = "huggingface_hub-0.23.5.tar.gz", hash = "sha256:67a9caba79b71235be3752852ca27da86bd54311d2424ca8afdb8dda056edf98"}, ] [package.dependencies] @@ -992,13 +994,13 @@ files = [ [[package]] name = "jsonschema" -version = "4.22.0" +version = "4.23.0" description = "An implementation of JSON Schema validation for Python" optional = true python-versions = ">=3.8" files = [ - {file = "jsonschema-4.22.0-py3-none-any.whl", hash = "sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802"}, - {file = "jsonschema-4.22.0.tar.gz", hash = "sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7"}, + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, ] [package.dependencies] @@ -1009,7 +1011,7 @@ rpds-py = ">=0.7.1" [package.extras] format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] -format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] [[package]] name = "jsonschema-specifications" @@ -1044,32 +1046,32 @@ regex = ["regex"] [[package]] name = "llvmlite" -version = "0.42.0" +version = "0.43.0" description = "lightweight wrapper around basic LLVM functionality" optional = true python-versions = ">=3.9" files = [ - {file = "llvmlite-0.42.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3366938e1bf63d26c34fbfb4c8e8d2ded57d11e0567d5bb243d89aab1eb56098"}, - {file = "llvmlite-0.42.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c35da49666a21185d21b551fc3caf46a935d54d66969d32d72af109b5e7d2b6f"}, - {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70f44ccc3c6220bd23e0ba698a63ec2a7d3205da0d848804807f37fc243e3f77"}, - {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d"}, - {file = "llvmlite-0.42.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d90edf400b4ceb3a0e776b6c6e4656d05c7187c439587e06f86afceb66d2be5"}, - {file = "llvmlite-0.42.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ae511caed28beaf1252dbaf5f40e663f533b79ceb408c874c01754cafabb9cbf"}, - {file = "llvmlite-0.42.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81e674c2fe85576e6c4474e8c7e7aba7901ac0196e864fe7985492b737dbab65"}, - {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb3975787f13eb97629052edb5017f6c170eebc1c14a0433e8089e5db43bcce6"}, - {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5bece0cdf77f22379f19b1959ccd7aee518afa4afbd3656c6365865f84903f9"}, - {file = "llvmlite-0.42.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e0c4c11c8c2aa9b0701f91b799cb9134a6a6de51444eff5a9087fc7c1384275"}, - {file = "llvmlite-0.42.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:08fa9ab02b0d0179c688a4216b8939138266519aaa0aa94f1195a8542faedb56"}, - {file = "llvmlite-0.42.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b2fce7d355068494d1e42202c7aff25d50c462584233013eb4470c33b995e3ee"}, - {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebe66a86dc44634b59a3bc860c7b20d26d9aaffcd30364ebe8ba79161a9121f4"}, - {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47494552559e00d81bfb836cf1c4d5a5062e54102cc5767d5aa1e77ccd2505c"}, - {file = "llvmlite-0.42.0-cp312-cp312-win_amd64.whl", hash = "sha256:05cb7e9b6ce69165ce4d1b994fbdedca0c62492e537b0cc86141b6e2c78d5888"}, - {file = "llvmlite-0.42.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdd3888544538a94d7ec99e7c62a0cdd8833609c85f0c23fcb6c5c591aec60ad"}, - {file = "llvmlite-0.42.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0936c2067a67fb8816c908d5457d63eba3e2b17e515c5fe00e5ee2bace06040"}, - {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a78ab89f1924fc11482209f6799a7a3fc74ddc80425a7a3e0e8174af0e9e2301"}, - {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7599b65c7af7abbc978dbf345712c60fd596aa5670496561cc10e8a71cebfb2"}, - {file = "llvmlite-0.42.0-cp39-cp39-win_amd64.whl", hash = "sha256:43d65cc4e206c2e902c1004dd5418417c4efa6c1d04df05c6c5675a27e8ca90e"}, - {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, + {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, + {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"}, + {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"}, + {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"}, + {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"}, + {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"}, + {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, ] [[package]] @@ -1295,31 +1297,27 @@ files = [ [[package]] name = "multiprocess" -version = "0.70.15" +version = "0.70.16" description = "better multiprocessing and multithreading in Python" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, - {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, - {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, - {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"}, - {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"}, - {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"}, - {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, ] [package.dependencies] -dill = ">=0.3.7" +dill = ">=0.3.8" [[package]] name = "nest-asyncio" @@ -1352,37 +1350,37 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "numba" -version = "0.59.1" +version = "0.60.0" description = "compiling Python code using LLVM" optional = true python-versions = ">=3.9" files = [ - {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, - {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, - {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, - {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, - {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, - {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, - {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, - {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, - {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, - {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, - {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, - {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, - {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, - {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, - {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, - {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, - {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, - {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, - {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, - {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, - {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, + {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, + {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, + {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"}, + {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"}, + {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"}, + {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"}, + {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"}, + {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"}, + {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"}, + {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"}, + {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"}, + {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"}, + {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"}, + {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"}, + {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"}, + {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"}, + {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"}, + {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"}, + {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"}, + {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"}, + {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"}, ] [package.dependencies] -llvmlite = "==0.42.*" -numpy = ">=1.22,<1.27" +llvmlite = "==0.43.*" +numpy = ">=1.22,<2.1" [[package]] name = "numpy" @@ -1551,13 +1549,14 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.5.40" +version = "12.5.82" description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"}, ] [[package]] @@ -1770,13 +1769,13 @@ test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets" [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -1883,84 +1882,95 @@ test = ["black", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameter [[package]] name = "pillow" -version = "10.3.0" +version = "10.4.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" files = [ - {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"}, - {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"}, - {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"}, - {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"}, - {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"}, - {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"}, - {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"}, - {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"}, - {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"}, - {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"}, - {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"}, - {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"}, - {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"}, - {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"}, - {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"}, - {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"}, - {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"}, - {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"}, - {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"}, - {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"}, - {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"}, - {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"}, - {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"}, - {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"}, - {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"}, - {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"}, - {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"}, - {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"}, - {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"}, - {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"}, - {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"}, - {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"}, - {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"}, - {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"}, - {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"}, + {file = "pillow-10.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e"}, + {file = "pillow-10.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc"}, + {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e"}, + {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46"}, + {file = "pillow-10.4.0-cp310-cp310-win32.whl", hash = "sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984"}, + {file = "pillow-10.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141"}, + {file = "pillow-10.4.0-cp310-cp310-win_arm64.whl", hash = "sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1"}, + {file = "pillow-10.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c"}, + {file = "pillow-10.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319"}, + {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d"}, + {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696"}, + {file = "pillow-10.4.0-cp311-cp311-win32.whl", hash = "sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496"}, + {file = "pillow-10.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91"}, + {file = "pillow-10.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22"}, + {file = "pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94"}, + {file = "pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a"}, + {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b"}, + {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9"}, + {file = "pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42"}, + {file = "pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a"}, + {file = "pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9"}, + {file = "pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3"}, + {file = "pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc"}, + {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a"}, + {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309"}, + {file = "pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060"}, + {file = "pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea"}, + {file = "pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d"}, + {file = "pillow-10.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736"}, + {file = "pillow-10.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd"}, + {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84"}, + {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0"}, + {file = "pillow-10.4.0-cp38-cp38-win32.whl", hash = "sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e"}, + {file = "pillow-10.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab"}, + {file = "pillow-10.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d"}, + {file = "pillow-10.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c"}, + {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1"}, + {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df"}, + {file = "pillow-10.4.0-cp39-cp39-win32.whl", hash = "sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef"}, + {file = "pillow-10.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5"}, + {file = "pillow-10.4.0-cp39-cp39-win_arm64.whl", hash = "sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3"}, + {file = "pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06"}, ] [package.extras] -docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] @@ -2018,27 +2028,28 @@ files = [ [[package]] name = "psutil" -version = "5.9.8" +version = "6.0.0" description = "Cross-platform lib for process and system monitoring in Python." optional = true -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, - {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, - {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, - {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, - {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, - {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, - {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, - {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, - {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, - {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, ] [package.extras] @@ -2057,157 +2068,181 @@ files = [ [[package]] name = "pyarrow" -version = "16.1.0" +version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" files = [ - {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, - {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, - {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, - {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, - {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, - {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, - {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, - {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + +[[package]] +name = "pyarrow-hotfix" +version = "0.6" +description = "" +optional = true +python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] + [[package]] name = "pydantic" -version = "2.7.3" +version = "2.8.2" description = "Data validation using Python type hints" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"}, - {file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"}, + {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, + {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, ] [package.dependencies] annotated-types = ">=0.4.0" -pydantic-core = "2.18.4" -typing-extensions = ">=4.6.1" +pydantic-core = "2.20.1" +typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} [package.extras] email = ["email-validator (>=2.0.0)"] [[package]] name = "pydantic-core" -version = "2.18.4" +version = "2.20.1" description = "Core functionality for Pydantic validation and serialization" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.18.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:f76d0ad001edd426b92233d45c746fd08f467d56100fd8f30e9ace4b005266e4"}, - {file = "pydantic_core-2.18.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:59ff3e89f4eaf14050c8022011862df275b552caef8082e37b542b066ce1ff26"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a55b5b16c839df1070bc113c1f7f94a0af4433fcfa1b41799ce7606e5c79ce0a"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4d0dcc59664fcb8974b356fe0a18a672d6d7cf9f54746c05f43275fc48636851"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8951eee36c57cd128f779e641e21eb40bc5073eb28b2d23f33eb0ef14ffb3f5d"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4701b19f7e3a06ea655513f7938de6f108123bf7c86bbebb1196eb9bd35cf724"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e00a3f196329e08e43d99b79b286d60ce46bed10f2280d25a1718399457e06be"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:97736815b9cc893b2b7f663628e63f436018b75f44854c8027040e05230eeddb"}, - {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6891a2ae0e8692679c07728819b6e2b822fb30ca7445f67bbf6509b25a96332c"}, - {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bc4ff9805858bd54d1a20efff925ccd89c9d2e7cf4986144b30802bf78091c3e"}, - {file = "pydantic_core-2.18.4-cp310-none-win32.whl", hash = "sha256:1b4de2e51bbcb61fdebd0ab86ef28062704f62c82bbf4addc4e37fa4b00b7cbc"}, - {file = "pydantic_core-2.18.4-cp310-none-win_amd64.whl", hash = "sha256:6a750aec7bf431517a9fd78cb93c97b9b0c496090fee84a47a0d23668976b4b0"}, - {file = "pydantic_core-2.18.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:942ba11e7dfb66dc70f9ae66b33452f51ac7bb90676da39a7345e99ffb55402d"}, - {file = "pydantic_core-2.18.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b2ebef0e0b4454320274f5e83a41844c63438fdc874ea40a8b5b4ecb7693f1c4"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a642295cd0c8df1b86fc3dced1d067874c353a188dc8e0f744626d49e9aa51c4"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f09baa656c904807e832cf9cce799c6460c450c4ad80803517032da0cd062e2"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98906207f29bc2c459ff64fa007afd10a8c8ac080f7e4d5beff4c97086a3dabd"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19894b95aacfa98e7cb093cd7881a0c76f55731efad31073db4521e2b6ff5b7d"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fbbdc827fe5e42e4d196c746b890b3d72876bdbf160b0eafe9f0334525119c8"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f85d05aa0918283cf29a30b547b4df2fbb56b45b135f9e35b6807cb28bc47951"}, - {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e85637bc8fe81ddb73fda9e56bab24560bdddfa98aa64f87aaa4e4b6730c23d2"}, - {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2f5966897e5461f818e136b8451d0551a2e77259eb0f73a837027b47dc95dab9"}, - {file = "pydantic_core-2.18.4-cp311-none-win32.whl", hash = "sha256:44c7486a4228413c317952e9d89598bcdfb06399735e49e0f8df643e1ccd0558"}, - {file = "pydantic_core-2.18.4-cp311-none-win_amd64.whl", hash = "sha256:8a7164fe2005d03c64fd3b85649891cd4953a8de53107940bf272500ba8a788b"}, - {file = "pydantic_core-2.18.4-cp311-none-win_arm64.whl", hash = "sha256:4e99bc050fe65c450344421017f98298a97cefc18c53bb2f7b3531eb39bc7805"}, - {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"}, - {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"}, - {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"}, - {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"}, - {file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"}, - {file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"}, - {file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"}, - {file = "pydantic_core-2.18.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:823be1deb01793da05ecb0484d6c9e20baebb39bd42b5d72636ae9cf8350dbd2"}, - {file = "pydantic_core-2.18.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ebef0dd9bf9b812bf75bda96743f2a6c5734a02092ae7f721c048d156d5fabae"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae1d6df168efb88d7d522664693607b80b4080be6750c913eefb77e34c12c71a"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9899c94762343f2cc2fc64c13e7cae4c3cc65cdfc87dd810a31654c9b7358cc"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99457f184ad90235cfe8461c4d70ab7dd2680e28821c29eca00252ba90308c78"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18f469a3d2a2fdafe99296a87e8a4c37748b5080a26b806a707f25a902c040a8"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7cdf28938ac6b8b49ae5e92f2735056a7ba99c9b110a474473fd71185c1af5d"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:938cb21650855054dc54dfd9120a851c974f95450f00683399006aa6e8abb057"}, - {file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:44cd83ab6a51da80fb5adbd9560e26018e2ac7826f9626bc06ca3dc074cd198b"}, - {file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:972658f4a72d02b8abfa2581d92d59f59897d2e9f7e708fdabe922f9087773af"}, - {file = "pydantic_core-2.18.4-cp38-none-win32.whl", hash = "sha256:1d886dc848e60cb7666f771e406acae54ab279b9f1e4143babc9c2258213daa2"}, - {file = "pydantic_core-2.18.4-cp38-none-win_amd64.whl", hash = "sha256:bb4462bd43c2460774914b8525f79b00f8f407c945d50881568f294c1d9b4443"}, - {file = "pydantic_core-2.18.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:44a688331d4a4e2129140a8118479443bd6f1905231138971372fcde37e43528"}, - {file = "pydantic_core-2.18.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a2fdd81edd64342c85ac7cf2753ccae0b79bf2dfa063785503cb85a7d3593223"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86110d7e1907ab36691f80b33eb2da87d780f4739ae773e5fc83fb272f88825f"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46387e38bd641b3ee5ce247563b60c5ca098da9c56c75c157a05eaa0933ed154"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:123c3cec203e3f5ac7b000bd82235f1a3eced8665b63d18be751f115588fea30"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc1803ac5c32ec324c5261c7209e8f8ce88e83254c4e1aebdc8b0a39f9ddb443"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53db086f9f6ab2b4061958d9c276d1dbe3690e8dd727d6abf2321d6cce37fa94"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:abc267fa9837245cc28ea6929f19fa335f3dc330a35d2e45509b6566dc18be23"}, - {file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a0d829524aaefdebccb869eed855e2d04c21d2d7479b6cada7ace5448416597b"}, - {file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:509daade3b8649f80d4e5ff21aa5673e4ebe58590b25fe42fac5f0f52c6f034a"}, - {file = "pydantic_core-2.18.4-cp39-none-win32.whl", hash = "sha256:ca26a1e73c48cfc54c4a76ff78df3727b9d9f4ccc8dbee4ae3f73306a591676d"}, - {file = "pydantic_core-2.18.4-cp39-none-win_amd64.whl", hash = "sha256:c67598100338d5d985db1b3d21f3619ef392e185e71b8d52bceacc4a7771ea7e"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"}, - {file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"}, + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, + {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, + {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, + {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, + {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, + {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, + {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, + {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, + {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, + {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, + {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, + {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, + {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, + {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, ] [package.dependencies] @@ -2446,110 +2481,110 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rpds-py" -version = "0.18.1" +version = "0.19.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = true python-versions = ">=3.8" files = [ - {file = "rpds_py-0.18.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d31dea506d718693b6b2cffc0648a8929bdc51c70a311b2770f09611caa10d53"}, - {file = "rpds_py-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:732672fbc449bab754e0b15356c077cc31566df874964d4801ab14f71951ea80"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a98a1f0552b5f227a3d6422dbd61bc6f30db170939bd87ed14f3c339aa6c7c9"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7f1944ce16401aad1e3f7d312247b3d5de7981f634dc9dfe90da72b87d37887d"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38e14fb4e370885c4ecd734f093a2225ee52dc384b86fa55fe3f74638b2cfb09"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08d74b184f9ab6289b87b19fe6a6d1a97fbfea84b8a3e745e87a5de3029bf944"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d70129cef4a8d979caa37e7fe957202e7eee8ea02c5e16455bc9808a59c6b2f0"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0bb20e3a11bd04461324a6a798af34d503f8d6f1aa3d2aa8901ceaf039176d"}, - {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81c5196a790032e0fc2464c0b4ab95f8610f96f1f2fa3d4deacce6a79852da60"}, - {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f3027be483868c99b4985fda802a57a67fdf30c5d9a50338d9db646d590198da"}, - {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d44607f98caa2961bab4fa3c4309724b185b464cdc3ba6f3d7340bac3ec97cc1"}, - {file = "rpds_py-0.18.1-cp310-none-win32.whl", hash = "sha256:c273e795e7a0f1fddd46e1e3cb8be15634c29ae8ff31c196debb620e1edb9333"}, - {file = "rpds_py-0.18.1-cp310-none-win_amd64.whl", hash = "sha256:8352f48d511de5f973e4f2f9412736d7dea76c69faa6d36bcf885b50c758ab9a"}, - {file = "rpds_py-0.18.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b5ff7e1d63a8281654b5e2896d7f08799378e594f09cf3674e832ecaf396ce8"}, - {file = "rpds_py-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8927638a4d4137a289e41d0fd631551e89fa346d6dbcfc31ad627557d03ceb6d"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:154bf5c93d79558b44e5b50cc354aa0459e518e83677791e6adb0b039b7aa6a7"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07f2139741e5deb2c5154a7b9629bc5aa48c766b643c1a6750d16f865a82c5fc"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c7672e9fba7425f79019db9945b16e308ed8bc89348c23d955c8c0540da0a07"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:489bdfe1abd0406eba6b3bb4fdc87c7fa40f1031de073d0cfb744634cc8fa261"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c20f05e8e3d4fc76875fc9cb8cf24b90a63f5a1b4c5b9273f0e8225e169b100"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:967342e045564cef76dfcf1edb700b1e20838d83b1aa02ab313e6a497cf923b8"}, - {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2cc7c1a47f3a63282ab0f422d90ddac4aa3034e39fc66a559ab93041e6505da7"}, - {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f7afbfee1157e0f9376c00bb232e80a60e59ed716e3211a80cb8506550671e6e"}, - {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e6934d70dc50f9f8ea47081ceafdec09245fd9f6032669c3b45705dea096b88"}, - {file = "rpds_py-0.18.1-cp311-none-win32.whl", hash = "sha256:c69882964516dc143083d3795cb508e806b09fc3800fd0d4cddc1df6c36e76bb"}, - {file = "rpds_py-0.18.1-cp311-none-win_amd64.whl", hash = "sha256:70a838f7754483bcdc830444952fd89645569e7452e3226de4a613a4c1793fb2"}, - {file = "rpds_py-0.18.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3dd3cd86e1db5aadd334e011eba4e29d37a104b403e8ca24dcd6703c68ca55b3"}, - {file = "rpds_py-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:05f3d615099bd9b13ecf2fc9cf2d839ad3f20239c678f461c753e93755d629ee"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b2b771b13eee8729a5049c976197ff58a27a3829c018a04341bcf1ae409b2b"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee17cd26b97d537af8f33635ef38be873073d516fd425e80559f4585a7b90c43"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b646bf655b135ccf4522ed43d6902af37d3f5dbcf0da66c769a2b3938b9d8184"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19ba472b9606c36716062c023afa2484d1e4220548751bda14f725a7de17b4f6"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e30ac5e329098903262dc5bdd7e2086e0256aa762cc8b744f9e7bf2a427d3f8"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d58ad6317d188c43750cb76e9deacf6051d0f884d87dc6518e0280438648a9ac"}, - {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e1735502458621921cee039c47318cb90b51d532c2766593be6207eec53e5c4c"}, - {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f5bab211605d91db0e2995a17b5c6ee5edec1270e46223e513eaa20da20076ac"}, - {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2fc24a329a717f9e2448f8cd1f960f9dac4e45b6224d60734edeb67499bab03a"}, - {file = "rpds_py-0.18.1-cp312-none-win32.whl", hash = "sha256:1805d5901779662d599d0e2e4159d8a82c0b05faa86ef9222bf974572286b2b6"}, - {file = "rpds_py-0.18.1-cp312-none-win_amd64.whl", hash = "sha256:720edcb916df872d80f80a1cc5ea9058300b97721efda8651efcd938a9c70a72"}, - {file = "rpds_py-0.18.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:c827576e2fa017a081346dce87d532a5310241648eb3700af9a571a6e9fc7e74"}, - {file = "rpds_py-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:aa3679e751408d75a0b4d8d26d6647b6d9326f5e35c00a7ccd82b78ef64f65f8"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0abeee75434e2ee2d142d650d1e54ac1f8b01e6e6abdde8ffd6eeac6e9c38e20"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed402d6153c5d519a0faf1bb69898e97fb31613b49da27a84a13935ea9164dfc"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:338dee44b0cef8b70fd2ef54b4e09bb1b97fc6c3a58fea5db6cc083fd9fc2724"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7750569d9526199c5b97e5a9f8d96a13300950d910cf04a861d96f4273d5b104"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:607345bd5912aacc0c5a63d45a1f73fef29e697884f7e861094e443187c02be5"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:207c82978115baa1fd8d706d720b4a4d2b0913df1c78c85ba73fe6c5804505f0"}, - {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6d1e42d2735d437e7e80bab4d78eb2e459af48c0a46e686ea35f690b93db792d"}, - {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5463c47c08630007dc0fe99fb480ea4f34a89712410592380425a9b4e1611d8e"}, - {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:06d218939e1bf2ca50e6b0ec700ffe755e5216a8230ab3e87c059ebb4ea06afc"}, - {file = "rpds_py-0.18.1-cp38-none-win32.whl", hash = "sha256:312fe69b4fe1ffbe76520a7676b1e5ac06ddf7826d764cc10265c3b53f96dbe9"}, - {file = "rpds_py-0.18.1-cp38-none-win_amd64.whl", hash = "sha256:9437ca26784120a279f3137ee080b0e717012c42921eb07861b412340f85bae2"}, - {file = "rpds_py-0.18.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:19e515b78c3fc1039dd7da0a33c28c3154458f947f4dc198d3c72db2b6b5dc93"}, - {file = "rpds_py-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7b28c5b066bca9a4eb4e2f2663012debe680f097979d880657f00e1c30875a0"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:673fdbbf668dd958eff750e500495ef3f611e2ecc209464f661bc82e9838991e"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d960de62227635d2e61068f42a6cb6aae91a7fe00fca0e3aeed17667c8a34611"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352a88dc7892f1da66b6027af06a2e7e5d53fe05924cc2cfc56495b586a10b72"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e0ee01ad8260184db21468a6e1c37afa0529acc12c3a697ee498d3c2c4dcaf3"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4c39ad2f512b4041343ea3c7894339e4ca7839ac38ca83d68a832fc8b3748ab"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aaa71ee43a703c321906813bb252f69524f02aa05bf4eec85f0c41d5d62d0f4c"}, - {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6cd8098517c64a85e790657e7b1e509b9fe07487fd358e19431cb120f7d96338"}, - {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4adec039b8e2928983f885c53b7cc4cda8965b62b6596501a0308d2703f8af1b"}, - {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32b7daaa3e9389db3695964ce8e566e3413b0c43e3394c05e4b243a4cd7bef26"}, - {file = "rpds_py-0.18.1-cp39-none-win32.whl", hash = "sha256:2625f03b105328729f9450c8badda34d5243231eef6535f80064d57035738360"}, - {file = "rpds_py-0.18.1-cp39-none-win_amd64.whl", hash = "sha256:bf18932d0003c8c4d51a39f244231986ab23ee057d235a12b2684ea26a353590"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cbfbea39ba64f5e53ae2915de36f130588bba71245b418060ec3330ebf85678e"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a3d456ff2a6a4d2adcdf3c1c960a36f4fd2fec6e3b4902a42a384d17cf4e7a65"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7700936ef9d006b7ef605dc53aa364da2de5a3aa65516a1f3ce73bf82ecfc7ae"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51584acc5916212e1bf45edd17f3a6b05fe0cbb40482d25e619f824dccb679de"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:942695a206a58d2575033ff1e42b12b2aece98d6003c6bc739fbf33d1773b12f"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b906b5f58892813e5ba5c6056d6a5ad08f358ba49f046d910ad992196ea61397"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6f8e3fecca256fefc91bb6765a693d96692459d7d4c644660a9fff32e517843"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7732770412bab81c5a9f6d20aeb60ae943a9b36dcd990d876a773526468e7163"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bd1105b50ede37461c1d51b9698c4f4be6e13e69a908ab7751e3807985fc0346"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:618916f5535784960f3ecf8111581f4ad31d347c3de66d02e728de460a46303c"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:17c6d2155e2423f7e79e3bb18151c686d40db42d8645e7977442170c360194d4"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c4c4c3f878df21faf5fac86eda32671c27889e13570645a9eea0a1abdd50922"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:fab6ce90574645a0d6c58890e9bcaac8d94dff54fb51c69e5522a7358b80ab64"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:531796fb842b53f2695e94dc338929e9f9dbf473b64710c28af5a160b2a8927d"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:740884bc62a5e2bbb31e584f5d23b32320fd75d79f916f15a788d527a5e83644"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:998125738de0158f088aef3cb264a34251908dd2e5d9966774fdab7402edfab7"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2be6e9dd4111d5b31ba3b74d17da54a8319d8168890fbaea4b9e5c3de630ae5"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0cee71bc618cd93716f3c1bf56653740d2d13ddbd47673efa8bf41435a60daa"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2c3caec4ec5cd1d18e5dd6ae5194d24ed12785212a90b37f5f7f06b8bedd7139"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:27bba383e8c5231cd559affe169ca0b96ec78d39909ffd817f28b166d7ddd4d8"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:a888e8bdb45916234b99da2d859566f1e8a1d2275a801bb8e4a9644e3c7e7909"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6031b25fb1b06327b43d841f33842b383beba399884f8228a6bb3df3088485ff"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48c2faaa8adfacefcbfdb5f2e2e7bdad081e5ace8d182e5f4ade971f128e6bb3"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:d85164315bd68c0806768dc6bb0429c6f95c354f87485ee3593c4f6b14def2bd"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6afd80f6c79893cfc0574956f78a0add8c76e3696f2d6a15bca2c66c415cf2d4"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa242ac1ff583e4ec7771141606aafc92b361cd90a05c30d93e343a0c2d82a89"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21be4770ff4e08698e1e8e0bce06edb6ea0626e7c8f560bc08222880aca6a6f"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c45a639e93a0c5d4b788b2613bd637468edd62f8f95ebc6fcc303d58ab3f0a8"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910e71711d1055b2768181efa0a17537b2622afeb0424116619817007f8a2b10"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9bb1f182a97880f6078283b3505a707057c42bf55d8fca604f70dedfdc0772a"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d54f74f40b1f7aaa595a02ff42ef38ca654b1469bef7d52867da474243cc633"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:8d2e182c9ee01135e11e9676e9a62dfad791a7a467738f06726872374a83db49"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:636a15acc588f70fda1661234761f9ed9ad79ebed3f2125d44be0862708b666e"}, - {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, + {file = "rpds_py-0.19.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:fb37bd599f031f1a6fb9e58ec62864ccf3ad549cf14bac527dbfa97123edcca4"}, + {file = "rpds_py-0.19.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3384d278df99ec2c6acf701d067147320b864ef6727405d6470838476e44d9e8"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e54548e0be3ac117595408fd4ca0ac9278fde89829b0b518be92863b17ff67a2"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8eb488ef928cdbc05a27245e52de73c0d7c72a34240ef4d9893fdf65a8c1a955"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5da93debdfe27b2bfc69eefb592e1831d957b9535e0943a0ee8b97996de21b5"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79e205c70afddd41f6ee79a8656aec738492a550247a7af697d5bd1aee14f766"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:959179efb3e4a27610e8d54d667c02a9feaa86bbabaf63efa7faa4dfa780d4f1"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a6e605bb9edcf010f54f8b6a590dd23a4b40a8cb141255eec2a03db249bc915b"}, + {file = "rpds_py-0.19.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9133d75dc119a61d1a0ded38fb9ba40a00ef41697cc07adb6ae098c875195a3f"}, + {file = "rpds_py-0.19.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dd36b712d35e757e28bf2f40a71e8f8a2d43c8b026d881aa0c617b450d6865c9"}, + {file = "rpds_py-0.19.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:354f3a91718489912f2e0fc331c24eaaf6a4565c080e00fbedb6015857c00582"}, + {file = "rpds_py-0.19.0-cp310-none-win32.whl", hash = "sha256:ebcbf356bf5c51afc3290e491d3722b26aaf5b6af3c1c7f6a1b757828a46e336"}, + {file = "rpds_py-0.19.0-cp310-none-win_amd64.whl", hash = "sha256:75a6076289b2df6c8ecb9d13ff79ae0cad1d5fb40af377a5021016d58cd691ec"}, + {file = "rpds_py-0.19.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6d45080095e585f8c5097897313def60caa2046da202cdb17a01f147fb263b81"}, + {file = "rpds_py-0.19.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c5c9581019c96f865483d031691a5ff1cc455feb4d84fc6920a5ffc48a794d8a"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1540d807364c84516417115c38f0119dfec5ea5c0dd9a25332dea60b1d26fc4d"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9e65489222b410f79711dc3d2d5003d2757e30874096b2008d50329ea4d0f88c"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9da6f400eeb8c36f72ef6646ea530d6d175a4f77ff2ed8dfd6352842274c1d8b"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37f46bb11858717e0efa7893c0f7055c43b44c103e40e69442db5061cb26ed34"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:071d4adc734de562bd11d43bd134330fb6249769b2f66b9310dab7460f4bf714"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9625367c8955e4319049113ea4f8fee0c6c1145192d57946c6ffcd8fe8bf48dd"}, + {file = "rpds_py-0.19.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e19509145275d46bc4d1e16af0b57a12d227c8253655a46bbd5ec317e941279d"}, + {file = "rpds_py-0.19.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d438e4c020d8c39961deaf58f6913b1bf8832d9b6f62ec35bd93e97807e9cbc"}, + {file = "rpds_py-0.19.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:90bf55d9d139e5d127193170f38c584ed3c79e16638890d2e36f23aa1630b952"}, + {file = "rpds_py-0.19.0-cp311-none-win32.whl", hash = "sha256:8d6ad132b1bc13d05ffe5b85e7a01a3998bf3a6302ba594b28d61b8c2cf13aaf"}, + {file = "rpds_py-0.19.0-cp311-none-win_amd64.whl", hash = "sha256:7ec72df7354e6b7f6eb2a17fa6901350018c3a9ad78e48d7b2b54d0412539a67"}, + {file = "rpds_py-0.19.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:5095a7c838a8647c32aa37c3a460d2c48debff7fc26e1136aee60100a8cd8f68"}, + {file = "rpds_py-0.19.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f2f78ef14077e08856e788fa482107aa602636c16c25bdf59c22ea525a785e9"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7cc6cb44f8636fbf4a934ca72f3e786ba3c9f9ba4f4d74611e7da80684e48d2"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf902878b4af334a09de7a45badbff0389e7cf8dc2e4dcf5f07125d0b7c2656d"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:688aa6b8aa724db1596514751ffb767766e02e5c4a87486ab36b8e1ebc1aedac"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57dbc9167d48e355e2569346b5aa4077f29bf86389c924df25c0a8b9124461fb"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b4cf5a9497874822341c2ebe0d5850fed392034caadc0bad134ab6822c0925b"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8a790d235b9d39c70a466200d506bb33a98e2ee374a9b4eec7a8ac64c2c261fa"}, + {file = "rpds_py-0.19.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1d16089dfa58719c98a1c06f2daceba6d8e3fb9b5d7931af4a990a3c486241cb"}, + {file = "rpds_py-0.19.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bc9128e74fe94650367fe23f37074f121b9f796cabbd2f928f13e9661837296d"}, + {file = "rpds_py-0.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c8f77e661ffd96ff104bebf7d0f3255b02aa5d5b28326f5408d6284c4a8b3248"}, + {file = "rpds_py-0.19.0-cp312-none-win32.whl", hash = "sha256:5f83689a38e76969327e9b682be5521d87a0c9e5a2e187d2bc6be4765f0d4600"}, + {file = "rpds_py-0.19.0-cp312-none-win_amd64.whl", hash = "sha256:06925c50f86da0596b9c3c64c3837b2481337b83ef3519e5db2701df695453a4"}, + {file = "rpds_py-0.19.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:52e466bea6f8f3a44b1234570244b1cff45150f59a4acae3fcc5fd700c2993ca"}, + {file = "rpds_py-0.19.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e21cc693045fda7f745c790cb687958161ce172ffe3c5719ca1764e752237d16"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b31f059878eb1f5da8b2fd82480cc18bed8dcd7fb8fe68370e2e6285fa86da6"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dd46f309e953927dd018567d6a9e2fb84783963650171f6c5fe7e5c41fd5666"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:34a01a4490e170376cd79258b7f755fa13b1a6c3667e872c8e35051ae857a92b"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bcf426a8c38eb57f7bf28932e68425ba86def6e756a5b8cb4731d8e62e4e0223"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68eea5df6347d3f1378ce992d86b2af16ad7ff4dcb4a19ccdc23dea901b87fb"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dab8d921b55a28287733263c0e4c7db11b3ee22aee158a4de09f13c93283c62d"}, + {file = "rpds_py-0.19.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6fe87efd7f47266dfc42fe76dae89060038f1d9cb911f89ae7e5084148d1cc08"}, + {file = "rpds_py-0.19.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:535d4b52524a961d220875688159277f0e9eeeda0ac45e766092bfb54437543f"}, + {file = "rpds_py-0.19.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:8b1a94b8afc154fbe36978a511a1f155f9bd97664e4f1f7a374d72e180ceb0ae"}, + {file = "rpds_py-0.19.0-cp38-none-win32.whl", hash = "sha256:7c98298a15d6b90c8f6e3caa6457f4f022423caa5fa1a1ca7a5e9e512bdb77a4"}, + {file = "rpds_py-0.19.0-cp38-none-win_amd64.whl", hash = "sha256:b0da31853ab6e58a11db3205729133ce0df26e6804e93079dee095be3d681dc1"}, + {file = "rpds_py-0.19.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5039e3cef7b3e7a060de468a4a60a60a1f31786da94c6cb054e7a3c75906111c"}, + {file = "rpds_py-0.19.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab1932ca6cb8c7499a4d87cb21ccc0d3326f172cfb6a64021a889b591bb3045c"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2afd2164a1e85226fcb6a1da77a5c8896c18bfe08e82e8ceced5181c42d2179"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b1c30841f5040de47a0046c243fc1b44ddc87d1b12435a43b8edff7e7cb1e0d0"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f757f359f30ec7dcebca662a6bd46d1098f8b9fb1fcd661a9e13f2e8ce343ba1"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15e65395a59d2e0e96caf8ee5389ffb4604e980479c32742936ddd7ade914b22"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb0f6eb3a320f24b94d177e62f4074ff438f2ad9d27e75a46221904ef21a7b05"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b228e693a2559888790936e20f5f88b6e9f8162c681830eda303bad7517b4d5a"}, + {file = "rpds_py-0.19.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2575efaa5d949c9f4e2cdbe7d805d02122c16065bfb8d95c129372d65a291a0b"}, + {file = "rpds_py-0.19.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:5c872814b77a4e84afa293a1bee08c14daed1068b2bb1cc312edbf020bbbca2b"}, + {file = "rpds_py-0.19.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:850720e1b383df199b8433a20e02b25b72f0fded28bc03c5bd79e2ce7ef050be"}, + {file = "rpds_py-0.19.0-cp39-none-win32.whl", hash = "sha256:ce84a7efa5af9f54c0aa7692c45861c1667080814286cacb9958c07fc50294fb"}, + {file = "rpds_py-0.19.0-cp39-none-win_amd64.whl", hash = "sha256:1c26da90b8d06227d7769f34915913911222d24ce08c0ab2d60b354e2d9c7aff"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:75969cf900d7be665ccb1622a9aba225cf386bbc9c3bcfeeab9f62b5048f4a07"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8445f23f13339da640d1be8e44e5baf4af97e396882ebbf1692aecd67f67c479"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5a7c1062ef8aea3eda149f08120f10795835fc1c8bc6ad948fb9652a113ca55"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:462b0c18fbb48fdbf980914a02ee38c423a25fcc4cf40f66bacc95a2d2d73bc8"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3208f9aea18991ac7f2b39721e947bbd752a1abbe79ad90d9b6a84a74d44409b"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3444fe52b82f122d8a99bf66777aed6b858d392b12f4c317da19f8234db4533"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88cb4bac7185a9f0168d38c01d7a00addece9822a52870eee26b8d5b61409213"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6b130bd4163c93798a6b9bb96be64a7c43e1cec81126ffa7ffaa106e1fc5cef5"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:a707b158b4410aefb6b054715545bbb21aaa5d5d0080217290131c49c2124a6e"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dc9ac4659456bde7c567107556ab065801622396b435a3ff213daef27b495388"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:81ea573aa46d3b6b3d890cd3c0ad82105985e6058a4baed03cf92518081eec8c"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f148c3f47f7f29a79c38cc5d020edcb5ca780020fab94dbc21f9af95c463581"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0906357f90784a66e89ae3eadc2654f36c580a7d65cf63e6a616e4aec3a81be"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f629ecc2db6a4736b5ba95a8347b0089240d69ad14ac364f557d52ad68cf94b0"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c6feacd1d178c30e5bc37184526e56740342fd2aa6371a28367bad7908d454fc"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae8b6068ee374fdfab63689be0963333aa83b0815ead5d8648389a8ded593378"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78d57546bad81e0da13263e4c9ce30e96dcbe720dbff5ada08d2600a3502e526"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8b6683a37338818646af718c9ca2a07f89787551057fae57c4ec0446dc6224b"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e8481b946792415adc07410420d6fc65a352b45d347b78fec45d8f8f0d7496f0"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bec35eb20792ea64c3c57891bc3ca0bedb2884fbac2c8249d9b731447ecde4fa"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:aa5476c3e3a402c37779e95f7b4048db2cb5b0ed0b9d006983965e93f40fe05a"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:19d02c45f2507b489fd4df7b827940f1420480b3e2e471e952af4d44a1ea8e34"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a3e2fd14c5d49ee1da322672375963f19f32b3d5953f0615b175ff7b9d38daed"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:93a91c2640645303e874eada51f4f33351b84b351a689d470f8108d0e0694210"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5b9fc03bf76a94065299d4a2ecd8dfbae4ae8e2e8098bbfa6ab6413ca267709"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a4b07cdf3f84310c08c1de2c12ddadbb7a77568bcb16e95489f9c81074322ed"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba0ed0dc6763d8bd6e5de5cf0d746d28e706a10b615ea382ac0ab17bb7388633"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:474bc83233abdcf2124ed3f66230a1c8435896046caa4b0b5ab6013c640803cc"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:329c719d31362355a96b435f4653e3b4b061fcc9eba9f91dd40804ca637d914e"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef9101f3f7b59043a34f1dccbb385ca760467590951952d6701df0da9893ca0c"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:0121803b0f424ee2109d6e1f27db45b166ebaa4b32ff47d6aa225642636cd834"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:8344127403dea42f5970adccf6c5957a71a47f522171fafaf4c6ddb41b61703a"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:443cec402ddd650bb2b885113e1dcedb22b1175c6be223b14246a714b61cd521"}, + {file = "rpds_py-0.19.0.tar.gz", hash = "sha256:4fdc9afadbeb393b4bbbad75481e0ea78e4469f2e1d713a90811700830b553a9"}, ] [[package]] @@ -2772,18 +2807,19 @@ files = [ [[package]] name = "setuptools" -version = "70.0.0" +version = "71.1.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, + {file = "setuptools-71.1.0-py3-none-any.whl", hash = "sha256:33874fdc59b3188304b2e7c80d9029097ea31627180896fb549c578ceb8a0855"}, + {file = "setuptools-71.1.0.tar.gz", hash = "sha256:032d42ee9fb536e33087fb66cac5f840eb9391ed05637b3f2a76a7c8fb477936"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -2798,29 +2834,32 @@ files = [ [[package]] name = "sympy" -version = "1.12.1" +version = "1.13.1" description = "Computer algebra system (CAS) in Python" optional = true python-versions = ">=3.8" files = [ - {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, - {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, + {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, + {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, ] [package.dependencies] -mpmath = ">=1.1.0,<1.4.0" +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "tbb" -version = "2021.12.0" +version = "2021.13.0" description = "Intel® oneAPI Threading Building Blocks (oneTBB)" optional = true python-versions = "*" files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, + {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, + {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, ] [[package]] @@ -2964,31 +3003,31 @@ files = [ [[package]] name = "torch" -version = "2.3.0" +version = "2.3.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = true python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, - {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, - {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, - {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, - {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, - {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, - {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, - {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, - {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, - {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, - {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, - {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, - {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, - {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, - {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, - {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, - {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, - {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, - {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, - {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, + {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, + {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, + {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, + {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, + {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, + {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, + {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, + {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, + {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, + {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, + {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, + {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, + {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, + {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, + {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, + {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, + {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, + {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, + {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, + {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, ] [package.dependencies] @@ -3009,7 +3048,7 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] @@ -3038,19 +3077,19 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.41.2" +version = "4.42.4" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"}, - {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"}, + {file = "transformers-4.42.4-py3-none-any.whl", hash = "sha256:6d59061392d0f1da312af29c962df9017ff3c0108c681a56d1bc981004d16d24"}, + {file = "transformers-4.42.4.tar.gz", hash = "sha256:f956e25e24df851f650cb2c158b6f4352dfae9d702f04c113ed24fc36ce7ae2d"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.23.0,<1.0" -numpy = ">=1.17" +huggingface-hub = ">=0.23.2,<1.0" +numpy = ">=1.17,<2.0" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" @@ -3062,14 +3101,15 @@ tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.2.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] @@ -3080,41 +3120,42 @@ natten = ["natten (>=0.14.6,<0.15.0)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.4.4)", "urllib3 (<2.0.0)"] ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.4.4)"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -timm = ["timm"] +timm = ["timm (<=0.9.16)"] tokenizers = ["tokenizers (>=0.19,<0.20)"] torch = ["accelerate (>=0.21.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17,<2.0)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.3.0" +version = "2.3.1" description = "A language and compiler for custom Deep Learning operations" optional = true python-versions = "*" files = [ - {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, - {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, - {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, - {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, - {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, - {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, + {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, + {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, + {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, + {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, + {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, + {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, ] [package.dependencies] @@ -3147,13 +3188,13 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6. [[package]] name = "typing-extensions" -version = "4.12.1" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, - {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -3169,13 +3210,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] @@ -3499,18 +3540,18 @@ multidict = ">=4.0" [[package]] name = "zipp" -version = "3.19.1" +version = "3.19.2" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, - {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, + {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, + {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, ] [package.extras] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] accelerate = ["accelerate"] @@ -3523,4 +3564,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "f62a7a74e1e1bcb3b7cb4f7da2b538065830748062a2b57fdbb4c76eae5abddc" +content-hash = "877137815e7080bf0f021c1a6c137b2d02042fc9340c862347aa07fde53b5f93" diff --git a/server/pyproject.toml b/server/pyproject.toml index 4cfbc1a2..0e9976ec 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,7 +26,7 @@ hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "^0.19.1" huggingface-hub = "^0.23" -transformers = "^4.41" +transformers = "^4.42" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 88fcc4f3..d074e406 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -1,48 +1,50 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index 5751bf81..d074e406 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -1,48 +1,50 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" +setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 88fcc4f3..d074e406 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -1,48 +1,50 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 94b69899..8ff51596 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -2,6 +2,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE from text_generation_server.layers.attention import Seqlen +from typing import Optional major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 @@ -43,6 +44,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -82,6 +84,8 @@ def paged_attention( # by the current path # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. + if softcap is None: + softcap = 0.0 out2 = flash_attn_2_cuda.varlen_fwd( query, key_cache, @@ -89,6 +93,7 @@ def paged_attention( None, seqlen.cu_seqlen_q, seqlen.cu_seqlen_k, + None, # pad_k None, block_tables, None, @@ -100,11 +105,14 @@ def paged_attention( True, # causal -1, # Window_left -1, # Window right + softcap, False, # return softmax None, # generator ) return out2[0] else: + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") input_lengths = seqlen.input_lengths from vllm._C import ops @@ -205,6 +213,7 @@ if V2: softmax_scale, window_size_left=-1, causal=True, + softcap=0.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -218,6 +227,7 @@ if V2: None, None, None, + None, max_s, max_s, 0.0, @@ -226,6 +236,7 @@ if V2: causal, window_size_left, 0, + softcap, False, None, ) @@ -241,11 +252,14 @@ else: max_s, softmax_scale, window_size_left=-1, + softcap=None, ): if window_size_left != -1: raise NotImplementedError( "window_size_left is only available with flash attn v2" ) + if softcap is not None: + raise NotImplementedError("softcap is only available with flash attn v2") # Flash attention v1 requires q, k and v to have the same number of heads if k.shape[1] != q.shape[1]: diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1cd13a2a..f6ffb4aa 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -762,6 +762,8 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + # hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it` + head_size=config_dict["head_dim"], ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 8526d515..0dc5b9cf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -189,6 +189,7 @@ class FlashGemma2Attention(torch.nn.Module): self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() ) + self.softcap = config.attn_logit_softcapping self.query_key_value = load_attention(config, prefix, weights) @@ -246,6 +247,7 @@ class FlashGemma2Attention(torch.nn.Module): self.softmax_scale, causal=self.causal, window_size_left=self.window_size, + softcap=self.softcap, ) # Decode else: @@ -259,6 +261,7 @@ class FlashGemma2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + softcap=self.softcap, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -466,6 +469,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module): config=config, weights=weights, ) + self.softcap = config.final_logit_softcapping + assert isinstance(self.softcap, float) def forward( self, @@ -495,4 +500,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) + + logits /= self.softcap + logits = torch.tanh(logits) + logits *= self.softcap + return logits, speculative_logits From 31eb03dbe2d14fccd10438208664c114b8126734 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 23 Jul 2024 11:16:03 +0200 Subject: [PATCH 152/340] Fixing mistral nemo. (#2276) --- server/text_generation_server/models/__init__.py | 2 -- .../models/custom_modeling/flash_mistral_modeling.py | 5 ++++- server/text_generation_server/models/flash_causal_lm.py | 7 ++++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f6ffb4aa..1cd13a2a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -762,8 +762,6 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, - # hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it` - head_size=config_dict["head_dim"], ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 8028dbe8..eeb3c45f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -117,7 +117,10 @@ class MistralAttention(torch.nn.Module): ) self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_heads + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = self.hidden_size // self.num_heads self.rotary_emb = PositionRotaryEmbedding.static( config=config, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index cfffafa1..5db62431 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -925,7 +925,12 @@ class FlashCausalLM(Model): assert self.num_kv_heads > 0 if head_size is None: - self.head_size = config.hidden_size // config.num_attention_heads + # Some models use GQA and different sizes for o_proj + # and q_proj, that allows for that. + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = config.hidden_size // config.num_attention_heads else: self.head_size = head_size From 919da25c3b27f39c7215b44f80fcce1dff28e33b Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 23 Jul 2024 09:24:29 +0000 Subject: [PATCH 153/340] fix(l4): fix fp8 logic on l4 (#2277) * fix(l4): fix fp8 logic on l4 * also quant weights with single scale * use marlin even on 89 --- server/text_generation_server/layers/fp8.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index bf5a0989..54550ee3 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module: """ if SYSTEM == "cuda": - major, minor = torch.cuda.get_device_capability() - if major == 8 and minor < 9: + major, _ = torch.cuda.get_device_capability() + if major == 8: from text_generation_server.layers.marlin import GPTQMarlinFP8Linear return GPTQMarlinFP8Linear @@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear -def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): - if FBGEMM_DYN_AVAILABLE: +def fp8_quantize( + weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False +): + if FBGEMM_DYN_AVAILABLE and not scalar: qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype ) @@ -186,6 +188,9 @@ class Fp8Linear(torch.nn.Module): dtype, ) -> None: super().__init__() + if FBGEMM_MM_AVAILABLE: + log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + self.dtype = dtype self.qweight = qweight self.scale = scale @@ -201,7 +206,7 @@ class Fp8Linear(torch.nn.Module): @classmethod def from_unquant(cls, weight, bias, dtype): - qweight, scale = fp8_quantize(weight) + qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) return cls( qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype ) @@ -232,7 +237,7 @@ class Fp8Linear(torch.nn.Module): ) return y.to(self.dtype) - qinput, scale = fp8_quantize(input) + qinput, scale = fp8_quantize(input, scalar=True) output, _ = torch._scaled_mm( qinput, self.qweight.t(), From 26460f053dc19e1134079913541f323f48bab205 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 23 Jul 2024 13:08:20 +0200 Subject: [PATCH 154/340] Add support for repacking AWQ weights for GPTQ-Marlin (#2278) * Add support for repacking AWQ weights for GPTQ-Marlin So far we couldn't support AWQ because virtually all AWQ models use symmetric quantization, which GPTQ-Marlin did not suppors. GPTQ-Marlin has recently added support AWQ repacking and AWQ asymmetric quantization (zero_point=True). This change updates all GPTQ-Marlin kernels from upstream and wires up AWQ support. For now enabling AWQ using Marlin requires running TGI with `--quantize gptq`. * Enable Marlin for supported AWQ configurations by default This makes the AWQ -> GPTQ repack test redundant, since we are now testing this with the regular AWQ test. --- .../marlin_kernels/awq_marlin_repack.cu | 269 +++++++++ server/marlin/marlin_kernels/ext.cpp | 2 + server/marlin/marlin_kernels/ext.hh | 16 +- server/marlin/marlin_kernels/fp8_marlin.cu | 33 +- server/marlin/marlin_kernels/gptq_marlin.cu | 525 ++++++++++++++---- .../marlin_kernels/gptq_marlin_repack.cu | 58 +- .../{gptq_marlin.cuh => marlin.cuh} | 15 +- .../marlin_kernels/marlin_cuda_kernel.cu | 46 +- ...tq_marlin_dtypes.cuh => marlin_dtypes.cuh} | 8 +- server/marlin/setup.py | 1 + .../layers/gptq/__init__.py | 63 ++- .../text_generation_server/layers/marlin.py | 166 +++++- .../utils/quantization.py | 17 +- 13 files changed, 1016 insertions(+), 203 deletions(-) create mode 100644 server/marlin/marlin_kernels/awq_marlin_repack.cu rename server/marlin/marlin_kernels/{gptq_marlin.cuh => marlin.cuh} (88%) rename server/marlin/marlin_kernels/{gptq_marlin_dtypes.cuh => marlin_dtypes.cuh} (93%) diff --git a/server/marlin/marlin_kernels/awq_marlin_repack.cu b/server/marlin/marlin_kernels/awq_marlin_repack.cu new file mode 100644 index 00000000..c58216d8 --- /dev/null +++ b/server/marlin/marlin_kernels/awq_marlin_repack.cu @@ -0,0 +1,269 @@ +#include "marlin.cuh" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +namespace marlin { + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} + +} // namespace marlin + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +namespace marlin { + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4* sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace marlin + + #define CALL_IF(NUM_BITS) \ + else if (num_bits == NUM_BITS) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", marlin::tile_k_size); + TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", marlin::tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK(b_q_weight.size(0) == size_k, + "b_q_weight.size(0) = ", b_q_weight.size(0), + " is not size_k = ", size_k); + TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_n = ", size_n, ", pack_factor = ", pack_factor); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = torch::empty( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); + + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4) + CALL_IF(8) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); + } + + return out; +} + +#endif diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 04e1530f..8dcd2ff9 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -3,6 +3,8 @@ #include "ext.hh" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("awq_marlin_repack", &awq_marlin_repack, + "Repack AWQ parameters for Marlin"); m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Marlin gemm with GPTQ compatibility"); m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm"); diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index 102c058e..ec76b5f4 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -6,11 +6,15 @@ // No support for async #else +torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits); + torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full); + torch::Tensor &b_scales, torch::Tensor &b_zeros, + torch::Tensor &g_idx, torch::Tensor &perm, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp); torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_meta, @@ -27,8 +31,8 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &workspace, int64_t size_m, int64_t size_n, int64_t size_k); -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, +torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k); diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu index aaef67e5..f6458ca6 100644 --- a/server/marlin/marlin_kernels/fp8_marlin.cu +++ b/server/marlin/marlin_kernels/fp8_marlin.cu @@ -19,10 +19,10 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "./gptq_marlin.cuh" -#include "./gptq_marlin_dtypes.cuh" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" -using namespace gptq_marlin; +using namespace marlin; #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", marlin::tile_size); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + " is not divisible by tile_size = ", marlin::tile_size); + int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, num_groups = b_scales.size(0); // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); @@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, b_scales.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { fp8_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + marlin::max_par); } else { TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); } diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu index 0beb9de1..122c5c16 100644 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ b/server/marlin/marlin_kernels/gptq_marlin.cu @@ -19,8 +19,8 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "gptq_marlin.cuh" -#include "gptq_marlin_dtypes.cuh" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -32,7 +32,7 @@ inline std::string str(T x) { return std::to_string(x); } -namespace gptq_marlin { +namespace marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -72,10 +72,11 @@ __global__ void Marlin( } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -264,6 +265,114 @@ dequant_8bit(int q) { return frag_b; } +// Zero-point dequantizers + +template +__device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit_zp( + int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_4bit_zp(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template +__device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit_zp( + int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit_zp(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template @@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType::FragB& frag_b, frag_b[1] = __hmul2(frag_b[1], s); } +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + // Same as above, but for act_order (each K is multiplied individually) template __device__ inline void scale4(typename ScalarType::FragB& frag_b, @@ -404,6 +524,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -413,6 +534,8 @@ __global__ void Marlin( int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m @@ -437,6 +560,7 @@ __global__ void Marlin( using FragB = typename ScalarType::FragB; using FragC = typename ScalarType::FragC; using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; constexpr int pack_factor = 32 / num_bits; @@ -566,6 +690,13 @@ __global__ void Marlin( int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -605,6 +736,19 @@ __global__ void Marlin( int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. @@ -616,6 +760,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. @@ -664,14 +820,17 @@ __global__ void Marlin( int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 // Zero accumulators. auto zero_accums = [&]() { @@ -777,6 +936,28 @@ __global__ void Marlin( } } } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -784,6 +965,12 @@ __global__ void Marlin( cp_async_fence(); }; + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + // Wait until the next thread tile has been loaded to shared memory. auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering @@ -932,8 +1119,73 @@ __global__ void Marlin( } }; + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + if constexpr (!has_zp) { + return; + } + + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + }; + // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + if constexpr (num_bits == 4) { + int zp_quant = frag_qzp[k % 2][0]; + int zp_quant_shift = zp_quant >> 8; + frag_zp_0 = dequant_4bit_zp(zp_quant); + frag_zp_1 = dequant_4bit_zp(zp_quant_shift); + + } else { + int zp_quant_0 = frag_qzp[k % 2][0]; + int zp_quant_1 = frag_qzp[k % 2][1]; + frag_zp_0 = dequant_8bit_zp(zp_quant_0); + frag_zp_1 = dequant_8bit_zp(zp_quant_1); + } + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -944,16 +1196,32 @@ __global__ void Marlin( int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); + if constexpr (has_zp) { + frag_b0 = dequant_4bit_zp(b_quant); + frag_b1 = dequant_4bit_zp(b_quant_shift); + + } else { + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + } } else { int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); + if constexpr (has_zp) { + frag_b0 = dequant_8bit_zp(b_quant_0); + frag_b1 = dequant_8bit_zp(b_quant_1); + } else { + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + } + + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); } // Apply scale to frag_b0 @@ -967,6 +1235,11 @@ __global__ void Marlin( } } + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], @@ -1189,6 +1462,12 @@ __global__ void Marlin( } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } fetch_to_shared(i, i, i < slice_iters); } @@ -1197,6 +1476,7 @@ __global__ void Marlin( init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; @@ -1217,6 +1497,7 @@ __global__ void Marlin( for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -1354,6 +1635,7 @@ __global__ void Marlin( } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } start_pipes(); @@ -1363,22 +1645,24 @@ __global__ void Marlin( } #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ Marlin, \ + HAS_ZP, GROUP_BLOCKS>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ + HAS_ZP, GROUP_BLOCKS> \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ + prob_m, prob_n, prob_k, locks); \ } typedef struct { @@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + + #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, void* workspace, int num_bits, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, int max_par) { + bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; @@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, thread_m_blocks = exec_cfg.max_m_blocks; } - // Define kernel configurations if (false) { } - CALL_IF(4, 32, 2, 256) - CALL_IF(4, 16, 4, 256) - CALL_IF(4, 8, 8, 256) - CALL_IF(4, 8, 4, 128) - CALL_IF(4, 4, 8, 128) - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) + GPTQ_CALL_IF(4, 16, 4, 256) + GPTQ_CALL_IF(4, 8, 8, 256) + GPTQ_CALL_IF(4, 8, 4, 128) + GPTQ_CALL_IF(4, 4, 8, 128) + GPTQ_CALL_IF(8, 16, 4, 256) + GPTQ_CALL_IF(8, 8, 8, 256) + GPTQ_CALL_IF(8, 8, 4, 128) + GPTQ_CALL_IF(8, 4, 8, 128) + + AWQ_CALL_IF(4, 16, 4, 256) + AWQ_CALL_IF(4, 8, 8, 256) + AWQ_CALL_IF(4, 8, 4, 128) + AWQ_CALL_IF(4, 4, 8, 128) + AWQ_CALL_IF(8, 16, 4, 256) + AWQ_CALL_IF(8, 8, 8, 256) + AWQ_CALL_IF(8, 8, 4, 128) + AWQ_CALL_IF(8, 4, 8, 128) else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; @@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp) { // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", marlin::tile_size); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + " is not divisible by tile_size = ", marlin::tile_size); + int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); @@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int group_size = -1; bool has_act_order = g_idx.size(0) != 0; - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), " is not size_n = ", size_n); num_groups = b_scales.size(0); @@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); + TORCH_CHECK(b_zeros.size(0) == num_groups, + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", b_scales.size(1), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { - gptq_marlin::marlin_mm_f16i4( + marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, gptq_marlin::max_par); + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { - gptq_marlin::marlin_mm_f16i4( + marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, - is_k_full, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/server/marlin/marlin_kernels/gptq_marlin_repack.cu b/server/marlin/marlin_kernels/gptq_marlin_repack.cu index 4adc158e..c71b1bf5 100644 --- a/server/marlin/marlin_kernels/gptq_marlin_repack.cu +++ b/server/marlin/marlin_kernels/gptq_marlin_repack.cu @@ -1,23 +1,16 @@ -#include "gptq_marlin.cuh" - -namespace gptq_marlin { - -static constexpr int repack_stages = 8; - -static constexpr int repack_threads = 256; - -static constexpr int tile_k_size = tile_size; -static constexpr int tile_n_size = tile_k_size * 4; +#include "marlin.cuh" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +namespace marlin { + template -__global__ void marlin_repack_kernel( +__global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {} -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, @@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, #else +namespace marlin { + template -__global__ void marlin_repack_kernel( +__global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { @@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel( } } -} // namespace gptq_marlin +} // namespace marlin - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); - TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", marlin::tile_k_size); + TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); - torch::Tensor out = - torch::empty({size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / pack_factor}, - options); + torch::Tensor out = torch::empty( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; diff --git a/server/marlin/marlin_kernels/gptq_marlin.cuh b/server/marlin/marlin_kernels/marlin.cuh similarity index 88% rename from server/marlin/marlin_kernels/gptq_marlin.cuh rename to server/marlin/marlin_kernels/marlin.cuh index 42af4495..74ccbac5 100644 --- a/server/marlin/marlin_kernels/gptq_marlin.cuh +++ b/server/marlin/marlin_kernels/marlin.cuh @@ -9,7 +9,9 @@ #include #include -namespace gptq_marlin { +namespace marlin { + +// Marlin params // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, @@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers template struct Vec { T elems[n]; @@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace gptq_marlin +} // namespace marlin diff --git a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu index d124c014..37339b84 100644 --- a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu +++ b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu @@ -30,7 +30,7 @@ inline std::string str(T x) { return std::to_string(x); } -namespace marlin { +namespace marlin_dense { constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } @@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, } } -} // namespace marlin +} // namespace marlin_dense torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, @@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(size_k == a.size(1), "Shape mismatch: a.size(1) = " + str(a.size(1)) + ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin::tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(marlin::tile_size)); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin_dense::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_dense::tile_size)); + TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin::tile_size)); + ", tile_size = " + str(marlin_dense::tile_size)); // Verify N TORCH_CHECK(b_scales.size(1) == size_n, "b_scales.size(1) = " + str(b_scales.size(1)) + ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin::tile_size)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_dense::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - int actual_size_n = - (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; + int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * + marlin_dense::pack_factor_4bit; TORCH_CHECK( size_n == actual_size_n, "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); @@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, "Unexpected groupsize = " + str(groupsize)); // Verify workspace size - TORCH_CHECK( - size_n % marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(marlin::min_thread_n)); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_dense::min_thread_n)); + int min_workspace_size = + (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = " + str(workspace.numel()) + " is below min_workspace_size = " + str(min_workspace_size)); int dev = a.get_device(); - marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, marlin::max_par); + marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, marlin_dense::max_par); return c; } diff --git a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh b/server/marlin/marlin_kernels/marlin_dtypes.cuh similarity index 93% rename from server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh rename to server/marlin/marlin_kernels/marlin_dtypes.cuh index ca1b7099..be06c09b 100644 --- a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh +++ b/server/marlin/marlin_kernels/marlin_dtypes.cuh @@ -1,11 +1,11 @@ #ifndef _data_types_cuh #define _data_types_cuh -#include "gptq_marlin.cuh" +#include "marlin.cuh" #include #include -namespace gptq_marlin { +namespace marlin { template class ScalarType {}; @@ -23,6 +23,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragZP = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); @@ -51,6 +52,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragZP = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { @@ -72,6 +74,6 @@ class ScalarType { #endif }; -} // namespace gptq_marlin +} // namespace marlin #endif diff --git a/server/marlin/setup.py b/server/marlin/setup.py index cc38bccf..f92eef04 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -9,6 +9,7 @@ setup( CUDAExtension( name="marlin_kernels", sources=[ + "marlin_kernels/awq_marlin_repack.cu", "marlin_kernels/fp8_marlin.cu", "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 6feca275..6d7495b7 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -156,16 +156,26 @@ class GPTQWeightsLoader(WeightsLoader): f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - g_idx = weights.get_tensor(f"{prefix}.g_idx") + if not self.sym: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") scales = weights.get_tensor(f"{prefix}.scales") return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=False, ) @@ -275,14 +285,26 @@ class GPTQWeightsLoader(WeightsLoader): quantize=self.quantize, sym=self.sym, ): - g_idx = weights.get_tensor(f"{prefix}.g_idx") + if not self.sym: + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=False, ) @@ -349,18 +371,31 @@ class GPTQWeightsLoader(WeightsLoader): quantize=self.quantize, sym=self.sym, ): - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] + + if not self.sym: + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=False, ) @@ -438,7 +473,19 @@ class GPTQWeightsLoader(WeightsLoader): f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if not self.sym: + if self.desc_act or self.groupsize == -1: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: scales = weights.get_tensor(f"{prefix}.scales") else: @@ -449,10 +496,12 @@ class GPTQWeightsLoader(WeightsLoader): return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=sharded_in_features, ) diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a28012da..cf622cc2 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union +import numpy import torch import torch.nn as nn from loguru import logger @@ -174,11 +175,12 @@ def can_use_gptq_marlin( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" + and quantize in {"awq", "gptq"} + and quant_method in {"awq", "gptq"} and bits in GPTQ_MARLIN_BITS and groupsize in GPTQ_MARLIN_GROUP_SIZES - and sym + # We only suppord asymmetric quantization for AWQ. + and (sym or quant_method == "awq") ) @@ -234,6 +236,7 @@ class GPTQMarlinWeight(Weight): """ qweight: torch.Tensor + qzeros: torch.Tensor scales: torch.Tensor g_idx: torch.Tensor perm: torch.Tensor @@ -256,11 +259,13 @@ class GPTQMarlinWeight(Weight): def repack_gptq_for_marlin( *, qweight: torch.Tensor, + qzeros: Optional[torch.Tensor], scales: torch.Tensor, - g_idx: torch.Tensor, + g_idx: Optional[torch.Tensor], bits: int, desc_act: bool, groupsize: int, + quant_method: str, sym: bool, sharded_infeatures: bool, ) -> GPTQMarlinWeight: @@ -279,30 +284,54 @@ def repack_gptq_for_marlin( raise RuntimeError( f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" ) - if not sym: + if not (sym or quant_method == "awq"): raise RuntimeError( "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." ) + log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") + weights_per_int = 32 // bits - in_features = qweight.shape[0] * weights_per_int + in_features = qweight.shape[0] out_features = qweight.shape[1] + # AWQ uses column packing, GPTQ uses row packing + if quant_method == "awq": + out_features *= weights_per_int + else: + in_features *= weights_per_int + if in_features % groupsize != 0: raise ValueError( f"Number of input features ({in_features}) not divisible by group size ({groupsize})" ) - if desc_act and groupsize != -1: + if g_idx is not None and desc_act and groupsize != -1: perm = torch.argsort(g_idx).to(torch.int) g_idx = g_idx[perm] else: perm = torch.empty(0, dtype=torch.int, device=qweight.device) g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) + if quant_method == "awq": + repacked = marlin_kernels.awq_marlin_repack( + qweight, in_features, out_features, bits + ) + if qzeros is not None: + qzeros = awq_to_marlin_zero_points( + qzeros, + in_features // groupsize, + out_features, + bits, + ) + + else: + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, bits + ) + + if qzeros is None: + qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) scales = permute_scales(scales) @@ -310,6 +339,7 @@ def repack_gptq_for_marlin( return GPTQMarlinWeight( qweight=repacked, + qzeros=qzeros, scales=scales, g_idx=g_idx, perm=perm, @@ -343,6 +373,7 @@ class GPTQMarlinLinear(nn.Module): self.is_full_k = weight.is_full_k self.qweight = weight.qweight + self.qzeros = weight.qzeros self.scales = weight.scales self.g_idx = weight.g_idx self.perm = weight.perm @@ -363,6 +394,7 @@ class GPTQMarlinLinear(nn.Module): A_flat, self.qweight, self.scales, + self.qzeros, self.g_idx, self.perm, self.workspace, @@ -371,6 +403,7 @@ class GPTQMarlinLinear(nn.Module): self.scales.shape[1], A_flat.shape[1], self.is_full_k, + self.qzeros.numel() > 0, ) C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) @@ -688,3 +721,116 @@ class MarlinLinear(nn.Module): C += self.bias return C + + +# Functions below are from vLLM + + +def get_pack_factor(bits: int) -> int: + if 32 % bits != 0: + raise ValueError(f"Cannot {bits} bit values into uint32") + return 32 // bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index c3c038fe..98dae079 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -34,7 +34,7 @@ def _get_quantizer_config(model_id, revision): groupsize = -1 quant_method = "gptq" checkpoint_format = None - sym = True + sym = False desc_act = False filename = "config.json" @@ -52,12 +52,17 @@ def _get_quantizer_config(model_id, revision): activation_scale_ub=data["quantization_config"]["activation_scale_ub"] ) + if "zero_point" in data["quantization_config"]: + sym = not data["quantization_config"]["zero_point"] + quant_method = "awq" + elif "sym" in data["quantization_config"]: + sym = data["quantization_config"]["sym"] + bits = data["quantization_config"]["bits"] groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] checkpoint_format = data["quantization_config"].get("checkpoint_format") - sym = data["quantization_config"]["sym"] desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" @@ -72,7 +77,13 @@ def _get_quantizer_config(model_id, revision): data = json.load(f) bits = data["bits"] groupsize = data["group_size"] - sym = data["sym"] + + if "zero_point" in data: + sym = not data["zero_point"] + quant_method = "awq" + elif "sym" in data: + sym = data["sym"] + desc_act = data["desc_act"] if "version" in data and data["version"] == "GEMM": quant_method = "awq" From 69b67b7add5bf7fa028e93dfc4d5bf43387a71f1 Mon Sep 17 00:00:00 2001 From: shaltielshmid Date: Tue, 23 Jul 2024 16:00:07 +0300 Subject: [PATCH 155/340] Add support for Mistral-Nemo by supporting head_dim through config (#2254) * Support passing head_dim through config * Using `head_dim` as a fallback is necessary since it's a non standard key in mistralConfig (as defined in transformers). * Shorter diff. --------- Co-authored-by: Nicolas Patry --- .../models/custom_modeling/flash_mistral_modeling.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index eeb3c45f..eca01bbb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -149,15 +149,14 @@ class MistralAttention(torch.nn.Module): bias=False, ) - head_size = config.hidden_size // config.num_attention_heads self.query_key_value = TensorParallelMultiAdapterLinear.load( query_key_value, layer_id, ["q_proj", "k_proj", "v_proj"], sizes=[ - head_size * config.num_attention_heads, - head_size * config.num_key_value_heads, - head_size * config.num_key_value_heads, + self.head_size * config.num_attention_heads, + self.head_size * config.num_key_value_heads, + self.head_size * config.num_key_value_heads, ], process_group=weights.process_group, ) From 5390973c09d1ca8981fac2e5fc849cc7cf4af4b7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 23 Jul 2024 16:20:17 +0200 Subject: [PATCH 156/340] Preparing for release. (#2285) * Preparing for release. * Updating docs. * Fixing token within the docker image for the launcher. --- docs/source/installation_amd.md | 2 +- docs/source/installation_intel.md | 2 +- docs/source/installation_nvidia.md | 2 +- docs/source/quicktour.md | 4 ++-- launcher/src/main.rs | 13 +++++++++++-- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 33d85732..931a9e3a 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.1-rocm \ + ghcr.io/huggingface/text-generation-inference:2.2.0-rocm \ --model-id $model ``` diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md index f9fda863..b3843490 100644 --- a/docs/source/installation_intel.md +++ b/docs/source/installation_intel.md @@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:latest-intel \ + ghcr.io/huggingface/text-generation-inference:2.2.0-intel \ --model-id $model --cuda-graphs 0 ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 4de6cb19..dac37d79 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.1 \ + ghcr.io/huggingface/text-generation-inference:2.2.0 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index f056baad..2313c69b 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.1 \ + ghcr.io/huggingface/text-generation-inference:2.2.0 \ --model-id $model ``` @@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. ```bash -docker run ghcr.io/huggingface/text-generation-inference:2.1.1 --help +docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help ``` diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 82dcf379..733cb1c1 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,7 +1,10 @@ /// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. use clap::{Parser, ValueEnum}; -use hf_hub::{api::sync::Api, Repo, RepoType}; +use hf_hub::{ + api::sync::{Api, ApiBuilder}, + Repo, RepoType, +}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use serde::Deserialize; @@ -1415,7 +1418,13 @@ fn main() -> Result<(), LauncherError> { let mut path = std::path::Path::new(&args.model_id).to_path_buf(); let filename = if !path.exists() { // Assume it's a hub id - let api = Api::new()?; + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; let repo = if let Some(ref revision) = args.revision { api.repo(Repo::with_revision( model_id, From 43f49141fd71105b9f34592dcf4aeb4071c157f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 23 Jul 2024 17:18:54 +0200 Subject: [PATCH 157/340] Add support for Llama 3 rotary embeddings (#2286) * Add support for Llama 3 rotary embeddings * Update transformers to 4.43 --- server/poetry.lock | 227 +++++++++--------- server/pyproject.toml | 2 +- server/requirements_cuda.txt | 6 +- server/requirements_intel.txt | 6 +- server/requirements_rocm.txt | 6 +- .../text_generation_server/layers/rotary.py | 60 ++++- 6 files changed, 176 insertions(+), 131 deletions(-) diff --git a/server/poetry.lock b/server/poetry.lock index 3dbd24e2..33a83b00 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -794,74 +794,66 @@ setuptools = "*" [[package]] name = "hf-transfer" -version = "0.1.6" -description = "" +version = "0.1.8" +description = "Speed up file transfers with the Hugging Face Hub." optional = false python-versions = ">=3.7" files = [ - {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, - {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, - {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, - {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, - {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, - {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, - {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, - {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, - {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, - {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, - {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, - {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, - {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, - {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, - {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, - {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, - {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, - {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, - {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, - {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, - {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, - {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, - {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, - {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, - {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, - {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, - {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, + {file = "hf_transfer-0.1.8-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:70858f9e94286738ed300484a45beb5cfee6a7ddac4c5886f9c6fce7823ac5ab"}, + {file = "hf_transfer-0.1.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38adc73f0a8526319d90f7cc5dc2d5e4bb66f487a513d94b98aa6725be732e4a"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44d2f0c08198d8d899fe9d66e86aee2dd844bd7ce33888f261373fcec81d2a54"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1de2a4ef36f9e60b3d3bec00193c0aafd75771709f2ca51b9b162373f5af3d32"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e319269e3606a5ff2979296841766649ac73598a4a8eee2a968f86c8071fea5a"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f6026cf3be6a53ea42f92172f60c1c0675baaa9073f865e671b661dde5fd157"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f865c33ada5bd3650c2b46e59979f2d7755c3f517f8d0facc78576a0c7d26406"}, + {file = "hf_transfer-0.1.8-cp310-none-win32.whl", hash = "sha256:2054730e8d8ed21917c64be7199e06424b2bd08df1c43a72766afaed7992f2d3"}, + {file = "hf_transfer-0.1.8-cp310-none-win_amd64.whl", hash = "sha256:2b4f1a9446ba31170b5b1eca4e916504d18378a6b5fe959896bdac8a736a5ecb"}, + {file = "hf_transfer-0.1.8-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e27c15fcc5869ad7e52bbc0bdec6106b288d1c463f8d2da92f28615a3b181361"}, + {file = "hf_transfer-0.1.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:871a0032d011ebc6409a73a8406b98b84ff2cd3ed7d9e1af8cdf4d660b9fab9b"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:686fa756e1e0214bb6327d33c66732c52274d94a8460beb50604ad988b391cf6"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:36a03b1b2911b0cf15b1b9d971a34b32dadcc4f2fd979aaff5979d6ce4017c34"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:079db90c81f41f4cf3227dfaaa855a9b8e9aef45bc7c2be29ce7232cd83ff881"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac08a4524127fdd14c234d4bcbe49d1c498acf5335c781714823179bcc8dc039"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837432e73cb17274a6782b6216e8ce058aa325a475dc44a5a6a753d48b86d18a"}, + {file = "hf_transfer-0.1.8-cp311-none-win32.whl", hash = "sha256:b180f9823dde35aba9bc0f1d0c04ac8a873baebd3732a7ffe4f11940abc7df0d"}, + {file = "hf_transfer-0.1.8-cp311-none-win_amd64.whl", hash = "sha256:37907d2135cebcf8b6d419bb575148d89c224f16b69357f027bd29d0e85c6529"}, + {file = "hf_transfer-0.1.8-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baf948f4f493949309cbe60529620b9b0aef854a22b6e526753364acc57c09b6"}, + {file = "hf_transfer-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bce5c8bdefa478c5d5eaa646cc4ce1df5cfe764d98572ad0c6b8773e98d49f6"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d6f8a1a86128d651a3799e1267c343d60f81f2c565d7c5416eb8e674e4cf0e"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f79fd1b0c2ed93efb4c5f684118d7a762ecdd218e170df8208c4e13d3dcd4959"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:414df35692670683bf5623498ef9d88a8df5d77e9516515da6e2b34d1054c11f"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c9798d5f951f66b96d40a7a53910260cb5874fda56cf5944dddb7c571f37ec3"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:060c661691f85a61392e57579c80eb64b5ee277434e81fb582f605c1c8ff05d5"}, + {file = "hf_transfer-0.1.8-cp312-none-win32.whl", hash = "sha256:f7840e32379820c3e1571a480238e05ea043e970c99d2e999578004a2eb17788"}, + {file = "hf_transfer-0.1.8-cp312-none-win_amd64.whl", hash = "sha256:9a3204ec423cc5e659872e8179f8704ad9ce2abb1e6a991f8838aedf1dc07830"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09949e86ad63ee139e463fd0dfaf401515ae70445854199f61d545514c65f744"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf1a74552845b93ea972e6e7131ef54e56056aa54137e93a40faf3fbcb2442ff"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959bcb3afb4ee6f2a07031a947dba98ec0b64c001bc914fbd8fc32e13a287162"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e01eecdb8162bd61dab9090fbd9f8034dd8b5755ef727a21ca8a057f80cb91ee"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50650a38e9d31f5ad8f010e4598bf304ecd99c17162e7d93f67e031571b864ee"}, + {file = "hf_transfer-0.1.8-cp37-none-win32.whl", hash = "sha256:e29b9d1d378138f2f4eae0e93ca94af3b5d45f4532eef69f1ab97fe06f9c9d9e"}, + {file = "hf_transfer-0.1.8-cp37-none-win_amd64.whl", hash = "sha256:cfd6cef43ae883103117a371f8ebae4e7f9637bc6fb480f1be5568e2fe22a8a7"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92a68f7a0043cca8a0de4decc760dca177530944cbab502afac503bd1b2fa01a"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e3138e408179f80a5480598e32f8e1abb564915cbde4d3bc8da52811c75dc3ea"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4544d148930ad34442d43b8fa911c8479c04a95b858b1d1f91e0b7da77082fad"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a851794b9f029965664f8c3002c957fccf21685e9397ceb4f9f19c986dee8ad3"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:791aaf87c5319ac83edb6ab2994b3db19924c49d6ff667dd3d8a610b455ff70a"}, + {file = "hf_transfer-0.1.8-cp38-none-win32.whl", hash = "sha256:8f71e5d35d3a3160dcca12fdcc8119033aeacaa6a32838a7ad9f9cb1008bbe58"}, + {file = "hf_transfer-0.1.8-cp38-none-win_amd64.whl", hash = "sha256:543287b4ceb1e25501580b99690f7f0df9d3631d29306f37cbd97e918c732944"}, + {file = "hf_transfer-0.1.8-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:7ce02a18bd0bb2343e707ac85b68c946bc37623ee24150c69158f6b2b2c7a98f"}, + {file = "hf_transfer-0.1.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64d7f8dbd64ba183ed1df75d47c84e075ff666ceaa335bff1de16b09eaac5b80"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e7858694e11419ae27e542fb8fc0d0e54d46ff7768fe73bc359d70b8f5aa578"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed116cd9d1edfa32c0136d7cb8e5f1afd2b32df43c49085d428f108fc8e1c8f"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e385d0da9c6b3472ab29285d2d46c9f9903205b8d108f88a82f3f85aafae0ab"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98f75fa4b86ef15433cd907807ac77d1fb39d7e7b790bfd39c7ae9c385bf0200"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a63ad947d2901425ac0a3ed70c3696dfde27fadb0482ed763bdd5cc946b278"}, + {file = "hf_transfer-0.1.8-cp39-none-win32.whl", hash = "sha256:3e74096915813ae842ea6a5bdf10c0fef960aa51a35a560955b3e61cdfe3db57"}, + {file = "hf_transfer-0.1.8-cp39-none-win_amd64.whl", hash = "sha256:05ea16307bf4a5eb097cbc6e5057e4eb5e080a138af23ef639fd38857723c288"}, + {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:928ff036c3e98e10dcfbdb4fcdfc4592d37a5cc8e365a7ba8dfd4337e849d675"}, + {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d49ba3ce67035f460ae1924fe2feafec155cb535eec7f31ed5109c19064cd294"}, + {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b01f5872c62cfee3ec9ca5c738818296f69f8adf84b4d8d15f2a5601d9dda339"}, + {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:659d4212d50847a5165666bf43d67727679b4f694ef9c413613cc27093136527"}, + {file = "hf_transfer-0.1.8.tar.gz", hash = "sha256:26d229468152e7a3ec12664cac86b8c2800695fd85f9c9a96677a775cc04f0b3"}, ] [[package]] @@ -1384,47 +1376,56 @@ numpy = ">=1.22,<2.1" [[package]] name = "numpy" -version = "1.26.4" +version = "2.0.1" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.9" files = [ - {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, - {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, - {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, - {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, - {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, - {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, - {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, - {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, - {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, - {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, - {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, - {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, - {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, - {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, - {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, - {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, - {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, - {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, - {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, - {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, - {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, - {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, - {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, - {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, - {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, - {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, - {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, - {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, - {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, - {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, - {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, - {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, - {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, + {file = "numpy-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0fbb536eac80e27a2793ffd787895242b7f18ef792563d742c2d673bfcb75134"}, + {file = "numpy-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:69ff563d43c69b1baba77af455dd0a839df8d25e8590e79c90fcbe1499ebde42"}, + {file = "numpy-2.0.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1b902ce0e0a5bb7704556a217c4f63a7974f8f43e090aff03fcf262e0b135e02"}, + {file = "numpy-2.0.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:f1659887361a7151f89e79b276ed8dff3d75877df906328f14d8bb40bb4f5101"}, + {file = "numpy-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4658c398d65d1b25e1760de3157011a80375da861709abd7cef3bad65d6543f9"}, + {file = "numpy-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4127d4303b9ac9f94ca0441138acead39928938660ca58329fe156f84b9f3015"}, + {file = "numpy-2.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e5eeca8067ad04bc8a2a8731183d51d7cbaac66d86085d5f4766ee6bf19c7f87"}, + {file = "numpy-2.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9adbd9bb520c866e1bfd7e10e1880a1f7749f1f6e5017686a5fbb9b72cf69f82"}, + {file = "numpy-2.0.1-cp310-cp310-win32.whl", hash = "sha256:7b9853803278db3bdcc6cd5beca37815b133e9e77ff3d4733c247414e78eb8d1"}, + {file = "numpy-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:81b0893a39bc5b865b8bf89e9ad7807e16717f19868e9d234bdaf9b1f1393868"}, + {file = "numpy-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75b4e316c5902d8163ef9d423b1c3f2f6252226d1aa5cd8a0a03a7d01ffc6268"}, + {file = "numpy-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6e4eeb6eb2fced786e32e6d8df9e755ce5be920d17f7ce00bc38fcde8ccdbf9e"}, + {file = "numpy-2.0.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a1e01dcaab205fbece13c1410253a9eea1b1c9b61d237b6fa59bcc46e8e89343"}, + {file = "numpy-2.0.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a8fc2de81ad835d999113ddf87d1ea2b0f4704cbd947c948d2f5513deafe5a7b"}, + {file = "numpy-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a3d94942c331dd4e0e1147f7a8699a4aa47dffc11bf8a1523c12af8b2e91bbe"}, + {file = "numpy-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15eb4eca47d36ec3f78cde0a3a2ee24cf05ca7396ef808dda2c0ddad7c2bde67"}, + {file = "numpy-2.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b83e16a5511d1b1f8a88cbabb1a6f6a499f82c062a4251892d9ad5d609863fb7"}, + {file = "numpy-2.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f87fec1f9bc1efd23f4227becff04bd0e979e23ca50cc92ec88b38489db3b55"}, + {file = "numpy-2.0.1-cp311-cp311-win32.whl", hash = "sha256:36d3a9405fd7c511804dc56fc32974fa5533bdeb3cd1604d6b8ff1d292b819c4"}, + {file = "numpy-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:08458fbf403bff5e2b45f08eda195d4b0c9b35682311da5a5a0a0925b11b9bd8"}, + {file = "numpy-2.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bf4e6f4a2a2e26655717a1983ef6324f2664d7011f6ef7482e8c0b3d51e82ac"}, + {file = "numpy-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6fddc5fe258d3328cd8e3d7d3e02234c5d70e01ebe377a6ab92adb14039cb4"}, + {file = "numpy-2.0.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5daab361be6ddeb299a918a7c0864fa8618af66019138263247af405018b04e1"}, + {file = "numpy-2.0.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:ea2326a4dca88e4a274ba3a4405eb6c6467d3ffbd8c7d38632502eaae3820587"}, + {file = "numpy-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:529af13c5f4b7a932fb0e1911d3a75da204eff023ee5e0e79c1751564221a5c8"}, + {file = "numpy-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6790654cb13eab303d8402354fabd47472b24635700f631f041bd0b65e37298a"}, + {file = "numpy-2.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cbab9fc9c391700e3e1287666dfd82d8666d10e69a6c4a09ab97574c0b7ee0a7"}, + {file = "numpy-2.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99d0d92a5e3613c33a5f01db206a33f8fdf3d71f2912b0de1739894668b7a93b"}, + {file = "numpy-2.0.1-cp312-cp312-win32.whl", hash = "sha256:173a00b9995f73b79eb0191129f2455f1e34c203f559dd118636858cc452a1bf"}, + {file = "numpy-2.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:bb2124fdc6e62baae159ebcfa368708867eb56806804d005860b6007388df171"}, + {file = "numpy-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc085b28d62ff4009364e7ca34b80a9a080cbd97c2c0630bb5f7f770dae9414"}, + {file = "numpy-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8fae4ebbf95a179c1156fab0b142b74e4ba4204c87bde8d3d8b6f9c34c5825ef"}, + {file = "numpy-2.0.1-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:72dc22e9ec8f6eaa206deb1b1355eb2e253899d7347f5e2fae5f0af613741d06"}, + {file = "numpy-2.0.1-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:ec87f5f8aca726117a1c9b7083e7656a9d0d606eec7299cc067bb83d26f16e0c"}, + {file = "numpy-2.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f682ea61a88479d9498bf2091fdcd722b090724b08b31d63e022adc063bad59"}, + {file = "numpy-2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8efc84f01c1cd7e34b3fb310183e72fcdf55293ee736d679b6d35b35d80bba26"}, + {file = "numpy-2.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3fdabe3e2a52bc4eff8dc7a5044342f8bd9f11ef0934fcd3289a788c0eb10018"}, + {file = "numpy-2.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:24a0e1befbfa14615b49ba9659d3d8818a0f4d8a1c5822af8696706fbda7310c"}, + {file = "numpy-2.0.1-cp39-cp39-win32.whl", hash = "sha256:f9cf5ea551aec449206954b075db819f52adc1638d46a6738253a712d553c7b4"}, + {file = "numpy-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:e9e81fa9017eaa416c056e5d9e71be93d05e2c3c2ab308d23307a8bc4443c368"}, + {file = "numpy-2.0.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:61728fba1e464f789b11deb78a57805c70b2ed02343560456190d0501ba37b0f"}, + {file = "numpy-2.0.1-pp39-pypy39_pp73-macosx_14_0_x86_64.whl", hash = "sha256:12f5d865d60fb9734e60a60f1d5afa6d962d8d4467c120a1c0cda6eb2964437d"}, + {file = "numpy-2.0.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eacf3291e263d5a67d8c1a581a8ebbcfd6447204ef58828caf69a5e3e8c75990"}, + {file = "numpy-2.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2c3a346ae20cfd80b6cfd3e60dc179963ef2ea58da5ec074fd3d9e7a1e7ba97f"}, + {file = "numpy-2.0.1.tar.gz", hash = "sha256:485b87235796410c3519a699cfe1faab097e509e90ebb05dcd098db2ae87e7b3"}, ] [[package]] @@ -3077,19 +3078,19 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.42.4" +version = "4.43.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.42.4-py3-none-any.whl", hash = "sha256:6d59061392d0f1da312af29c962df9017ff3c0108c681a56d1bc981004d16d24"}, - {file = "transformers-4.42.4.tar.gz", hash = "sha256:f956e25e24df851f650cb2c158b6f4352dfae9d702f04c113ed24fc36ce7ae2d"}, + {file = "transformers-4.43.0-py3-none-any.whl", hash = "sha256:0158b430c3bab1abde3232598e250916edd906680b9ffe9d4e1b0b3fafbfc0ed"}, + {file = "transformers-4.43.0.tar.gz", hash = "sha256:4154e659eff31b0bf3b21b2c5d62e38ac8973e33c53cfe28284a19c8558234e2"}, ] [package.dependencies] filelock = "*" huggingface-hub = ">=0.23.2,<1.0" -numpy = ">=1.17,<2.0" +numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" @@ -3101,14 +3102,14 @@ tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] benchmark = ["optimum-benchmark (>=0.2.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -3131,15 +3132,15 @@ sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm (<=0.9.16)"] tokenizers = ["tokenizers (>=0.19,<0.20)"] torch = ["accelerate (>=0.21.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17,<2.0)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] @@ -3564,4 +3565,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "877137815e7080bf0f021c1a6c137b2d02042fc9340c862347aa07fde53b5f93" +content-hash = "ca4bcd76fba70266e0945883398a0ab5e38e751ea5e6776264cb6c5979cc2e40" diff --git a/server/pyproject.toml b/server/pyproject.toml index 0e9976ec..82d43af2 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,7 +26,7 @@ hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "^0.19.1" huggingface-hub = "^0.23" -transformers = "^4.42" +transformers = "^4.43" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index d074e406..47876068 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -11,12 +11,12 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index d074e406..47876068 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -11,12 +11,12 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index d074e406..47876068 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -11,12 +11,12 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index db78ee1c..e6e96abb 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -1,4 +1,5 @@ import os +import math import torch from torch import nn from loguru import logger @@ -85,9 +86,13 @@ class PositionRotaryEmbedding(nn.Module): scaling_factor = None rope_scaling = _get_rope_config(config) if rope_scaling is not None: - if rope_scaling["type"] == "linear": + # `rope_type` is now standard in transformers, but some existing models + # have `type` instead. + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) + + if rope_type == "linear": pass - elif rope_scaling["type"] == "dynamic": + elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( dim=dim, @@ -96,7 +101,20 @@ class PositionRotaryEmbedding(nn.Module): device=inv_freq.device, scaling_factor=scaling_factor, ) - elif rope_scaling["type"] == "yarn": + elif rope_type == "llama3": + inv_freq = apply_llama3_scaling( + inv_freq, + scaling_factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + ) + + return cls(inv_freq, scaling_factor) + + elif rope_type == "yarn": scaling_factor = rope_scaling["factor"] mscale = rope_scaling.get("mscale", 1.0) mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) @@ -115,7 +133,7 @@ class PositionRotaryEmbedding(nn.Module): mscale=mscale, mscale_all_dim=mscale_all_dim, ) - elif rope_scaling["type"] in ["su", "longrope"]: + elif rope_type in ["su", "longrope"]: short_factor = torch.tensor( rope_scaling["short_factor"], dtype=torch.float32, device=device ) @@ -327,10 +345,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): self._sin_cached = torch.sin(freqs).to(dtype) -# Inverse dim formula to find dim based on number of rotations -import math - - def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 2 * math.log(base) @@ -434,3 +448,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): freqs = torch.outer(t, self.inv_freq.to(device=t.device)) self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) + + +def apply_llama3_scaling( + freqs: torch.Tensor, + *, + scaling_factor: int, + low_freq_factor: int, + high_freq_factor: int, + original_max_position_embeddings: int, +): + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + new_freqs = [] + + for freq in freqs: + wavelen = 2 * math.pi / freq + + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scaling_factor) + else: + + assert low_freq_wavelen != high_freq_wavelen + smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) + + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) From b1077b077cf5a79460135d68b80c860e544361be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 23 Jul 2024 17:53:19 +0200 Subject: [PATCH 158/340] hotfix: pin numpy (#2289) --- server/poetry.lock | 91 ++++++++++++++++------------------- server/pyproject.toml | 2 + server/requirements_cuda.txt | 4 +- server/requirements_intel.txt | 4 +- server/requirements_rocm.txt | 4 +- 5 files changed, 49 insertions(+), 56 deletions(-) diff --git a/server/poetry.lock b/server/poetry.lock index 33a83b00..4dde0dab 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1376,56 +1376,47 @@ numpy = ">=1.22,<2.1" [[package]] name = "numpy" -version = "2.0.1" +version = "1.26.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.9" files = [ - {file = "numpy-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0fbb536eac80e27a2793ffd787895242b7f18ef792563d742c2d673bfcb75134"}, - {file = "numpy-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:69ff563d43c69b1baba77af455dd0a839df8d25e8590e79c90fcbe1499ebde42"}, - {file = "numpy-2.0.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1b902ce0e0a5bb7704556a217c4f63a7974f8f43e090aff03fcf262e0b135e02"}, - {file = "numpy-2.0.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:f1659887361a7151f89e79b276ed8dff3d75877df906328f14d8bb40bb4f5101"}, - {file = "numpy-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4658c398d65d1b25e1760de3157011a80375da861709abd7cef3bad65d6543f9"}, - {file = "numpy-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4127d4303b9ac9f94ca0441138acead39928938660ca58329fe156f84b9f3015"}, - {file = "numpy-2.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e5eeca8067ad04bc8a2a8731183d51d7cbaac66d86085d5f4766ee6bf19c7f87"}, - {file = "numpy-2.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9adbd9bb520c866e1bfd7e10e1880a1f7749f1f6e5017686a5fbb9b72cf69f82"}, - {file = "numpy-2.0.1-cp310-cp310-win32.whl", hash = "sha256:7b9853803278db3bdcc6cd5beca37815b133e9e77ff3d4733c247414e78eb8d1"}, - {file = "numpy-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:81b0893a39bc5b865b8bf89e9ad7807e16717f19868e9d234bdaf9b1f1393868"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75b4e316c5902d8163ef9d423b1c3f2f6252226d1aa5cd8a0a03a7d01ffc6268"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6e4eeb6eb2fced786e32e6d8df9e755ce5be920d17f7ce00bc38fcde8ccdbf9e"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a1e01dcaab205fbece13c1410253a9eea1b1c9b61d237b6fa59bcc46e8e89343"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a8fc2de81ad835d999113ddf87d1ea2b0f4704cbd947c948d2f5513deafe5a7b"}, - {file = "numpy-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a3d94942c331dd4e0e1147f7a8699a4aa47dffc11bf8a1523c12af8b2e91bbe"}, - {file = "numpy-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15eb4eca47d36ec3f78cde0a3a2ee24cf05ca7396ef808dda2c0ddad7c2bde67"}, - {file = "numpy-2.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b83e16a5511d1b1f8a88cbabb1a6f6a499f82c062a4251892d9ad5d609863fb7"}, - {file = "numpy-2.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f87fec1f9bc1efd23f4227becff04bd0e979e23ca50cc92ec88b38489db3b55"}, - {file = "numpy-2.0.1-cp311-cp311-win32.whl", hash = "sha256:36d3a9405fd7c511804dc56fc32974fa5533bdeb3cd1604d6b8ff1d292b819c4"}, - {file = "numpy-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:08458fbf403bff5e2b45f08eda195d4b0c9b35682311da5a5a0a0925b11b9bd8"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bf4e6f4a2a2e26655717a1983ef6324f2664d7011f6ef7482e8c0b3d51e82ac"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6fddc5fe258d3328cd8e3d7d3e02234c5d70e01ebe377a6ab92adb14039cb4"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5daab361be6ddeb299a918a7c0864fa8618af66019138263247af405018b04e1"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:ea2326a4dca88e4a274ba3a4405eb6c6467d3ffbd8c7d38632502eaae3820587"}, - {file = "numpy-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:529af13c5f4b7a932fb0e1911d3a75da204eff023ee5e0e79c1751564221a5c8"}, - {file = "numpy-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6790654cb13eab303d8402354fabd47472b24635700f631f041bd0b65e37298a"}, - {file = "numpy-2.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cbab9fc9c391700e3e1287666dfd82d8666d10e69a6c4a09ab97574c0b7ee0a7"}, - {file = "numpy-2.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99d0d92a5e3613c33a5f01db206a33f8fdf3d71f2912b0de1739894668b7a93b"}, - {file = "numpy-2.0.1-cp312-cp312-win32.whl", hash = "sha256:173a00b9995f73b79eb0191129f2455f1e34c203f559dd118636858cc452a1bf"}, - {file = "numpy-2.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:bb2124fdc6e62baae159ebcfa368708867eb56806804d005860b6007388df171"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc085b28d62ff4009364e7ca34b80a9a080cbd97c2c0630bb5f7f770dae9414"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8fae4ebbf95a179c1156fab0b142b74e4ba4204c87bde8d3d8b6f9c34c5825ef"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:72dc22e9ec8f6eaa206deb1b1355eb2e253899d7347f5e2fae5f0af613741d06"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:ec87f5f8aca726117a1c9b7083e7656a9d0d606eec7299cc067bb83d26f16e0c"}, - {file = "numpy-2.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f682ea61a88479d9498bf2091fdcd722b090724b08b31d63e022adc063bad59"}, - {file = "numpy-2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8efc84f01c1cd7e34b3fb310183e72fcdf55293ee736d679b6d35b35d80bba26"}, - {file = "numpy-2.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3fdabe3e2a52bc4eff8dc7a5044342f8bd9f11ef0934fcd3289a788c0eb10018"}, - {file = "numpy-2.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:24a0e1befbfa14615b49ba9659d3d8818a0f4d8a1c5822af8696706fbda7310c"}, - {file = "numpy-2.0.1-cp39-cp39-win32.whl", hash = "sha256:f9cf5ea551aec449206954b075db819f52adc1638d46a6738253a712d553c7b4"}, - {file = "numpy-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:e9e81fa9017eaa416c056e5d9e71be93d05e2c3c2ab308d23307a8bc4443c368"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:61728fba1e464f789b11deb78a57805c70b2ed02343560456190d0501ba37b0f"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-macosx_14_0_x86_64.whl", hash = "sha256:12f5d865d60fb9734e60a60f1d5afa6d962d8d4467c120a1c0cda6eb2964437d"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eacf3291e263d5a67d8c1a581a8ebbcfd6447204ef58828caf69a5e3e8c75990"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2c3a346ae20cfd80b6cfd3e60dc179963ef2ea58da5ec074fd3d9e7a1e7ba97f"}, - {file = "numpy-2.0.1.tar.gz", hash = "sha256:485b87235796410c3519a699cfe1faab097e509e90ebb05dcd098db2ae87e7b3"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] [[package]] @@ -3078,13 +3069,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.43.0" +version = "4.43.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.43.0-py3-none-any.whl", hash = "sha256:0158b430c3bab1abde3232598e250916edd906680b9ffe9d4e1b0b3fafbfc0ed"}, - {file = "transformers-4.43.0.tar.gz", hash = "sha256:4154e659eff31b0bf3b21b2c5d62e38ac8973e33c53cfe28284a19c8558234e2"}, + {file = "transformers-4.43.1-py3-none-any.whl", hash = "sha256:eb44b731902e062acbaff196ae4896d7cb3494ddf38275aa00a5fcfb5b34f17d"}, + {file = "transformers-4.43.1.tar.gz", hash = "sha256:662252c4d0e31b6684f68f68d5cc8206dd7f83da80eb3235be3dc5b3c9fdbdbd"}, ] [package.dependencies] @@ -3565,4 +3556,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "ca4bcd76fba70266e0945883398a0ab5e38e751ea5e6776264cb6c5979cc2e40" +content-hash = "52d2dda4603443ee76ce880f3f1fdf7f044e7ba6df9973d230f52b0fbc0dbf24" diff --git a/server/pyproject.toml b/server/pyproject.toml index 82d43af2..44f9cc13 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -37,6 +37,8 @@ pillow = "^10.0.0" outlines= { version = "^0.0.34", optional = true } prometheus-client = "^0.20.0" py-cpuinfo = "^9.0.0" +# Remove later, temporary workaround for outlines. +numpy = "^1.26" [tool.poetry.extras] torch = ["torch"] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 47876068..828b6fca 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -16,7 +16,7 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index 47876068..828b6fca 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -16,7 +16,7 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 47876068..828b6fca 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -16,7 +16,7 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" From 34c472bd64452b1faf8de4046acbf075fb9f71fc Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 23 Jul 2024 20:39:43 +0000 Subject: [PATCH 159/340] chore: update to torch 2.4 (#2259) * chore: update to torch 2.4 * remove un-necessary patch * fix --- server/Makefile | 1 - server/Makefile-fbgemm | 6 ++---- server/pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/server/Makefile b/server/Makefile index 209fc44e..af5ffa59 100644 --- a/server/Makefile +++ b/server/Makefile @@ -30,7 +30,6 @@ install: install-cuda install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm pip install -e ".[bnb]" - pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm index 38f8f31f..57526577 100644 --- a/server/Makefile-fbgemm +++ b/server/Makefile-fbgemm @@ -1,10 +1,8 @@ -fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca +fbgemm_commit := ddac8dd9fc0bee70a3f456df68b8aac38576c856 build-fbgemm: - chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \ git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ - cp fbgemm_remove_unused.patch fbgemm && \ - cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \ + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \ git submodule update --init --recursive && \ cd fbgemm_gpu && \ pip install -r requirements.txt && \ diff --git a/server/pyproject.toml b/server/pyproject.toml index 44f9cc13..f8c3cb8e 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -31,7 +31,7 @@ einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } peft = { version = "^0.10", optional = true } -torch = { version = "^2.3.0", optional = true } +torch = { version = "^2.4.0", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" outlines= { version = "^0.0.34", optional = true } From a994f6aeddb873a2da87e2643fbea9a8650ac640 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 23 Jul 2024 23:31:28 +0200 Subject: [PATCH 160/340] hotfix: update nccl --- server/Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/server/Makefile b/server/Makefile index af5ffa59..209fc44e 100644 --- a/server/Makefile +++ b/server/Makefile @@ -30,6 +30,7 @@ install: install-cuda install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm pip install -e ".[bnb]" + pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm From 204142153f2a0030ec63c328bb8f737fa3bad475 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Wed, 24 Jul 2024 16:39:08 +0800 Subject: [PATCH 161/340] fix crash in multi-modal (#2245) * fix crash in multi-modal Signed-off-by: Wang, Yi A * update according to review comment Signed-off-by: Wang, Yi A * fix llava_next regression in latest main Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A --- .../models/custom_modeling/flash_llama_modeling.py | 2 +- .../text_generation_server/models/custom_modeling/idefics2.py | 1 + .../text_generation_server/models/custom_modeling/llava_next.py | 1 + server/text_generation_server/models/vlm_causal_lm.py | 2 ++ 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f7980d2d..3e8e67ab 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -424,7 +424,7 @@ class FlashLlamaModel(torch.nn.Module): FlashLlamaLayer( index=0, prefix=( - "model.layers.0" if not prefix else "{prefix}.model.layers.0" + "model.layers.0" if not prefix else f"{prefix}.model.layers.0" ), config=config, weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 735c3899..f540b7af 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -832,6 +832,7 @@ class Idefics2ForConditionalGeneration(nn.Module): max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 567131ef..e154d805 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -280,6 +280,7 @@ class LlavaNextForConditionalGeneration(nn.Module): max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 308d5a3d..7de54aa4 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -14,6 +14,7 @@ from text_generation_server.models.flash_causal_lm import ( ) from text_generation_server.utils.log import log_master from transformers import AutoProcessor +from text_generation_server.layers.attention import Seqlen tracer = trace.get_tracer(__name__) @@ -348,6 +349,7 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, From d93931567d861f822f6ebac178fce3b08c729377 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Wed, 24 Jul 2024 16:44:02 +0800 Subject: [PATCH 162/340] =?UTF-8?q?fix=20of=20use=20of=20unquantized=20wei?= =?UTF-8?q?ghts=20in=20cohere=20GQA=20loading,=20also=20enable=20=E2=80=A6?= =?UTF-8?q?=20(#2291)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix of use of unquantized weights in cohere GQA loading, also enable the model in intel platform Signed-off-by: Wang, Yi A --- .../custom_modeling/flash_cohere_modeling.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index c7b29d13..ab10dee1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight if SYSTEM == "cuda": import dropout_layer_norm @@ -83,6 +84,12 @@ class CohereRotary(PositionRotaryEmbedding): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." @@ -99,7 +106,7 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": + if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) @@ -166,16 +173,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.attention_bias: w = [ From 457791f511772d1456e7322c465cf39986dc35be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 24 Jul 2024 16:33:26 +0200 Subject: [PATCH 163/340] Split up `layers.marlin` into several files (#2292) The marlin.py file was getting large, split it up. --- server/tests/utils/test_weights.py | 5 +- .../text_generation_server/layers/marlin.py | 836 ------------------ .../layers/marlin/__init__.py | 20 + .../layers/marlin/fp8.py | 140 +++ .../layers/marlin/gptq.py | 266 ++++++ .../layers/marlin/marlin.py | 346 ++++++++ .../layers/marlin/util.py | 141 +++ 7 files changed, 917 insertions(+), 837 deletions(-) delete mode 100644 server/text_generation_server/layers/marlin.py create mode 100644 server/text_generation_server/layers/marlin/__init__.py create mode 100644 server/text_generation_server/layers/marlin/fp8.py create mode 100644 server/text_generation_server/layers/marlin/gptq.py create mode 100644 server/text_generation_server/layers/marlin/marlin.py create mode 100644 server/text_generation_server/layers/marlin/util.py diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index d2d2b76e..97e30c46 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -8,7 +8,10 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader -from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader +from text_generation_server.layers.marlin.marlin import ( + MarlinWeight, + MarlinWeightsLoader, +) from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py deleted file mode 100644 index cf622cc2..00000000 --- a/server/text_generation_server/layers/marlin.py +++ /dev/null @@ -1,836 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.fp8 import fp8_quantize -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -GPTQ_MARLIN_BITS = [4, 8] -GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] -MARLIN_TILE_SIZE = 16 - - -class MarlinWeightsLoader(WeightsLoader): - """Loader for Marlin-quantized weights.""" - - def __init__(self, *, bits: int, is_marlin_24: bool): - self.bits = bits - self.is_marlin_24 = is_marlin_24 - - def get_weights(self, weights: "Weights", prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = weights.get_tensor(f"{prefix}.B_24") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_tensor(f"{prefix}.B_meta") - s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = weights.get_tensor(f"{prefix}.B") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - s = weights.get_tensor(f"{prefix}.s") - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - if self.is_marlin_24: - B = weights.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = weights.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - B = weights.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - if self.is_marlin_24: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" - ) - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_row(self, weights: Weights, prefix: str): - if self.is_marlin_24: - try: - B = weights.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = weights.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - - return weight - - -def can_use_gptq_marlin( - *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool -) -> bool: - return ( - SYSTEM == "cuda" - and marlin_kernels is not None - and has_sm_8_0 - and quantize in {"awq", "gptq"} - and quant_method in {"awq", "gptq"} - and bits in GPTQ_MARLIN_BITS - and groupsize in GPTQ_MARLIN_GROUP_SIZES - # We only suppord asymmetric quantization for AWQ. - and (sym or quant_method == "awq") - ) - - -def _check_marlin_kernels(): - if not (SYSTEM == "cuda" and has_sm_8_0): - raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." - ) - - if marlin_kernels is None: - raise NotImplementedError( - "marlin is not installed, install it with: pip install server/marlin" - ) - - -def _check_valid_shape(in_features: int, out_features: int): - if (in_features % 128 != 0 or out_features % 64 != 0) and ( - in_features % 64 != 0 or out_features % 128 != 0 - ): - raise ValueError( - f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." - " The shape elements must be divisible by (128, 64) or (64, 128)." - ) - - -# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 -def _get_perms() -> Tuple[List[int], List[int]]: - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -_scale_perm, _scale_perm_single = _get_perms() - - -def permute_scales(scales: torch.Tensor): - out_features = scales.shape[1] - if scales.shape[0] == 1: - scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] - else: - scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm] - return scales.reshape((-1, out_features)).contiguous() - - -@dataclass -class GPTQMarlinWeight(Weight): - """ - Repacked GPTQ Marlin weights. - """ - - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - bits: int - is_full_k: bool - - def __post_init__(self): - assert self.qweight.dtype == torch.int32 - assert self.scales.dtype == torch.float16 - assert self.g_idx.dtype == torch.int32 - assert self.perm.dtype == torch.int32 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlinLinear( - weight=self, - bias=bias, - ) - - -def repack_gptq_for_marlin( - *, - qweight: torch.Tensor, - qzeros: Optional[torch.Tensor], - scales: torch.Tensor, - g_idx: Optional[torch.Tensor], - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - sym: bool, - sharded_infeatures: bool, -) -> GPTQMarlinWeight: - """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" - _check_marlin_kernels() - assert marlin_kernels is not None - - if bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" - ) - - if groupsize not in GPTQ_MARLIN_GROUP_SIZES: - supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) - raise RuntimeError( - f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" - ) - if not (sym or quant_method == "awq"): - raise RuntimeError( - "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." - ) - - log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") - - weights_per_int = 32 // bits - in_features = qweight.shape[0] - out_features = qweight.shape[1] - - # AWQ uses column packing, GPTQ uses row packing - if quant_method == "awq": - out_features *= weights_per_int - else: - in_features *= weights_per_int - - if in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisible by group size ({groupsize})" - ) - - if g_idx is not None and desc_act and groupsize != -1: - perm = torch.argsort(g_idx).to(torch.int) - g_idx = g_idx[perm] - else: - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - - if quant_method == "awq": - repacked = marlin_kernels.awq_marlin_repack( - qweight, in_features, out_features, bits - ) - if qzeros is not None: - qzeros = awq_to_marlin_zero_points( - qzeros, - in_features // groupsize, - out_features, - bits, - ) - - else: - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) - - if qzeros is None: - qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) - - scales = permute_scales(scales) - - is_full_k = not (desc_act and sharded_infeatures) - - return GPTQMarlinWeight( - qweight=repacked, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - bits=bits, - is_full_k=is_full_k, - ) - - -class GPTQMarlinLinear(nn.Module): - """ - Linear layer for GPTQ weights that were converted for the GPTQ-Marlin - kernels. - """ - - def __init__( - self, - *, - weight: GPTQMarlinWeight, - bias: Optional[torch.Tensor], - ): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE - out_features = weight.scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.bits = weight.bits - self.is_full_k = weight.is_full_k - - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx - self.perm = weight.perm - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.gptq_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.perm, - self.workspace, - self.bits, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - self.is_full_k, - self.qzeros.numel() > 0, - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - - -@dataclass -class GPTQMarlin24Weight: - """ - GPTQ-Marlin 2:4 weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - B_meta (torch.Tensor): metadata for 2:4 sparsity. - s (torch.Tensor): float16 scales. - bits: quantized weight size. - """ - - B: torch.Tensor - B_meta: torch.Tensor - s: torch.Tensor - bits: int - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.B_meta.dtype == torch.int16 - assert self.s.dtype == torch.float16 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlin24Linear( - weight=self, - bias=bias, - ) - - -class GPTQMarlin24Linear(nn.Module): - def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - if weight.bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" - ) - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 - out_features = weight.s.shape[1] - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - - if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: - supported_sizes = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ) - raise RuntimeError( - f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" - ) - - self.bits = weight.bits - weights_per_int32 = 32 // self.bits - - assert ( - out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" - assert ( - out_features % weights_per_int32 == 0 - ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" - - assert ( - in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" - if groupsize != -1 and in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisable by group size ({groupsize})" - ) - - self.B = weight.B - self.B_meta = weight.B_meta - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, - dtype=torch.int, - device=weight.B.device, - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.gptq_marlin_24_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.B_meta, - self.s, - self.workspace, - self.bits, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -class GPTQMarlinFP8Linear(nn.Module): - """ - FP8 GPTQ-Marlin linear layer. - """ - - def __init__( - self, - qweight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> None: - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - - scales = scales.unsqueeze(0) - if scales.shape[1] == 1: - out_features, in_features = qweight.shape - scales = scales.repeat(1, out_features) - qweight, scales = repack_fp8_for_marlin(qweight, scales) - - in_features = qweight.shape[0] * MARLIN_TILE_SIZE - out_features = scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.qweight = qweight - self.scales = scales - self.bias = bias if bias is not None else None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=qweight.device - ) - - @classmethod - def from_unquant(cls, weight, bias, dtype): - qweight, scales = fp8_quantize(weight) - return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) - - @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, dtype): - return cls(qweight=weight, scales=scale.to(dtype), bias=bias) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.fp8_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.workspace, - 8, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements). - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - - if fp8_tensor.shape[0] % 4 != 0: - raise ValueError( - f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" - ) - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = torch.zeros( - fp8_tensor.shape[0] // 4, - fp8_tensor.shape[1], - dtype=torch.int32, - device=fp8_tensor.device, - ) - - for i in range(4): - packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) - - return packed - - -def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): - """ - Repack FP8 tensor for GPTQ-Marlin. - """ - - out_features, in_features = weight.shape - - # Torch linear layers weights with shape [out_features, in_features], - # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], - # so transpose before packing. - qweight = pack_fp8_as_int32(weight.t()) - - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, 8 - ) - - scales = permute_scales(scales) - - return repacked, scales - - -@dataclass -class MarlinWeight(Weight): - """ - Marlin weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): bfloat16/float16 scales. - """ - - B: torch.Tensor - s: torch.Tensor - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.s.dtype in [torch.float16, torch.bfloat16] - - def get_linear(self, bias: torch.Tensor): - return MarlinLinear(weight=self, bias=bias) - - -class MarlinLinear(nn.Module): - def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE - out_features = weight.s.shape[1] - assert ( - in_features % 128 == 0 - ), f"Number of input features ({in_features}) not divisable by 128" - assert ( - out_features % 256 == 0 - ), f"Number of output features ({out_features}) not divisable by 256" - - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - assert groupsize in { - -1, - 128, - }, f"Group size must be -1 or 128, was {groupsize}" - - self.B = weight.B - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.B.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.marlin_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.s, - self.workspace, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -# Functions below are from vLLM - - -def get_pack_factor(bits: int) -> int: - if 32 % bits != 0: - raise ValueError(f"Cannot {bits} bit values into uint32") - return 32 // bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - - -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp diff --git a/server/text_generation_server/layers/marlin/__init__.py b/server/text_generation_server/layers/marlin/__init__.py new file mode 100644 index 00000000..40147d59 --- /dev/null +++ b/server/text_generation_server/layers/marlin/__init__.py @@ -0,0 +1,20 @@ +from typing import List, Tuple + +import torch +from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear +from text_generation_server.layers.marlin.gptq import ( + GPTQMarlinLinear, + GPTQMarlinWeight, + can_use_gptq_marlin, + repack_gptq_for_marlin, +) +from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader + +__all__ = [ + "GPTQMarlinFP8Linear", + "GPTQMarlinLinear", + "GPTQMarlinWeight", + "MarlinWeightsLoader", + "can_use_gptq_marlin", + "repack_gptq_for_marlin", +] diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py new file mode 100644 index 00000000..fe55a58a --- /dev/null +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -0,0 +1,140 @@ +from typing import Optional + +import torch +import torch.nn as nn +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize +from text_generation_server.layers.marlin.gptq import _check_valid_shape +from text_generation_server.layers.marlin.util import ( + _check_marlin_kernels, + permute_scales, +) +from text_generation_server.utils.log import log_once + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + +MARLIN_TILE_SIZE = 16 + + +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + qweight: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + scales = scales.unsqueeze(0) + if scales.shape[1] == 1: + out_features, in_features = qweight.shape + scales = scales.repeat(1, out_features) + qweight, scales = repack_fp8_for_marlin(qweight, scales) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scales = fp8_quantize(weight) + return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + return cls(qweight=weight, scales=scale.to(dtype), bias=bias) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + + if fp8_tensor.shape[0] % 4 != 0: + raise ValueError( + f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" + ) + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = torch.zeros( + fp8_tensor.shape[0] // 4, + fp8_tensor.shape[1], + dtype=torch.int32, + device=fp8_tensor.device, + ) + + for i in range(4): + packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) + + return packed + + +def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): + """ + Repack FP8 tensor for GPTQ-Marlin. + """ + + out_features, in_features = weight.shape + + # Torch linear layers weights with shape [out_features, in_features], + # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], + # so transpose before packing. + qweight = pack_fp8_as_int32(weight.t()) + + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, 8 + ) + + scales = permute_scales(scales) + + return repacked, scales diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py new file mode 100644 index 00000000..697634ea --- /dev/null +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -0,0 +1,266 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy +import torch +import torch.nn as nn +from loguru import logger +from text_generation_server.layers.marlin.util import ( + _check_marlin_kernels, + marlin_zero_points, + permute_scales, + unpack_cols, +) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +GPTQ_MARLIN_BITS = [4, 8] +GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] +MARLIN_TILE_SIZE = 16 + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: + return ( + SYSTEM == "cuda" + and marlin_kernels is not None + and has_sm_8_0 + and quantize in {"awq", "gptq"} + and quant_method in {"awq", "gptq"} + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + # We only suppord asymmetric quantization for AWQ. + and (sym or quant_method == "awq") + ) + + +@dataclass +class GPTQMarlinWeight(Weight): + """ + Repacked GPTQ Marlin weights. + """ + + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: torch.Tensor + perm: torch.Tensor + bits: int + is_full_k: bool + + def __post_init__(self): + assert self.qweight.dtype == torch.int32 + assert self.scales.dtype == torch.float16 + assert self.g_idx.dtype == torch.int32 + assert self.perm.dtype == torch.int32 + + def get_linear(self, bias: torch.Tensor): + return GPTQMarlinLinear( + weight=self, + bias=bias, + ) + + +def repack_gptq_for_marlin( + *, + qweight: torch.Tensor, + qzeros: Optional[torch.Tensor], + scales: torch.Tensor, + g_idx: Optional[torch.Tensor], + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + sym: bool, + sharded_infeatures: bool, +) -> GPTQMarlinWeight: + """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" + _check_marlin_kernels() + assert marlin_kernels is not None + + if bits not in GPTQ_MARLIN_BITS: + supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) + raise RuntimeError( + f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" + ) + + if groupsize not in GPTQ_MARLIN_GROUP_SIZES: + supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) + raise RuntimeError( + f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" + ) + if not (sym or quant_method == "awq"): + raise RuntimeError( + "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." + ) + + log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") + + weights_per_int = 32 // bits + in_features = qweight.shape[0] + out_features = qweight.shape[1] + + # AWQ uses column packing, GPTQ uses row packing + if quant_method == "awq": + out_features *= weights_per_int + else: + in_features *= weights_per_int + + if in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisible by group size ({groupsize})" + ) + + if g_idx is not None and desc_act and groupsize != -1: + perm = torch.argsort(g_idx).to(torch.int) + g_idx = g_idx[perm] + else: + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) + + if quant_method == "awq": + repacked = marlin_kernels.awq_marlin_repack( + qweight, in_features, out_features, bits + ) + if qzeros is not None: + qzeros = awq_to_marlin_zero_points( + qzeros, + in_features // groupsize, + out_features, + bits, + ) + + else: + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, bits + ) + + if qzeros is None: + qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) + + scales = permute_scales(scales) + + is_full_k = not (desc_act and sharded_infeatures) + + return GPTQMarlinWeight( + qweight=repacked, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + perm=perm, + bits=bits, + is_full_k=is_full_k, + ) + + +class GPTQMarlinLinear(nn.Module): + """ + Linear layer for GPTQ weights that were converted for the GPTQ-Marlin + kernels. + """ + + def __init__( + self, + *, + weight: GPTQMarlinWeight, + bias: Optional[torch.Tensor], + ): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE + out_features = weight.scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.bits = weight.bits + self.is_full_k = weight.is_full_k + + self.qweight = weight.qweight + self.qzeros = weight.qzeros + self.scales = weight.scales + self.g_idx = weight.g_idx + self.perm = weight.perm + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.gptq_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.perm, + self.workspace, + self.bits, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + self.is_full_k, + self.qzeros.numel() > 0, + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def _check_valid_shape(in_features: int, out_features: int): + if (in_features % 128 != 0 or out_features % 64 != 0) and ( + in_features % 64 != 0 or out_features % 128 != 0 + ): + raise ValueError( + f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." + " The shape elements must be divisible by (128, 64) or (64, 128)." + ) diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py new file mode 100644 index 00000000..db3ce2d7 --- /dev/null +++ b/server/text_generation_server/layers/marlin/marlin.py @@ -0,0 +1,346 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from text_generation_server.layers.marlin.util import _check_marlin_kernels +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + +class MarlinWeightsLoader(WeightsLoader): + """Loader for Marlin-quantized weights.""" + + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_row(self, weights: Weights, prefix: str): + if self.is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +@dataclass +class MarlinWeight(Weight): + """ + Marlin weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. + s (torch.Tensor): bfloat16/float16 scales. + """ + + B: torch.Tensor + s: torch.Tensor + + def __post_init__(self): + assert self.B.dtype == torch.int32 + assert self.s.dtype in [torch.float16, torch.bfloat16] + + def get_linear(self, bias: torch.Tensor): + return MarlinLinear(weight=self, bias=bias) + + +class MarlinLinear(nn.Module): + def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE + out_features = weight.s.shape[1] + assert ( + in_features % 128 == 0 + ), f"Number of input features ({in_features}) not divisable by 128" + assert ( + out_features % 256 == 0 + ), f"Number of output features ({out_features}) not divisable by 256" + + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + assert groupsize in { + -1, + 128, + }, f"Group size must be -1 or 128, was {groupsize}" + + self.B = weight.B + self.s = weight.s + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=weight.B.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + C = marlin_kernels.marlin_gemm( + A.view(-1, A.shape[-1]), + self.B, + self.s, + self.workspace, + A.shape[0], + self.s.shape[1], + A.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_TILE_SIZE = 16 + + +@dataclass +class GPTQMarlin24Weight: + """ + GPTQ-Marlin 2:4 weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. + B_meta (torch.Tensor): metadata for 2:4 sparsity. + s (torch.Tensor): float16 scales. + bits: quantized weight size. + """ + + B: torch.Tensor + B_meta: torch.Tensor + s: torch.Tensor + bits: int + + def __post_init__(self): + assert self.B.dtype == torch.int32 + assert self.B_meta.dtype == torch.int16 + assert self.s.dtype == torch.float16 + + def get_linear(self, bias: torch.Tensor): + return GPTQMarlin24Linear( + weight=self, + bias=bias, + ) + + +class GPTQMarlin24Linear(nn.Module): + def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + supported_bits = ", ".join( + str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + ) + raise RuntimeError( + f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" + ) + + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 + out_features = weight.s.shape[1] + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + + if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + supported_sizes = ", ".join( + str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ) + raise RuntimeError( + f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" + ) + + self.bits = weight.bits + weights_per_int32 = 32 // self.bits + + assert ( + out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 + ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" + assert ( + out_features % weights_per_int32 == 0 + ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" + + assert ( + in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 + ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" + if groupsize != -1 and in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisable by group size ({groupsize})" + ) + + self.B = weight.B + self.B_meta = weight.B_meta + self.s = weight.s + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, + dtype=torch.int, + device=weight.B.device, + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + C = marlin_kernels.gptq_marlin_24_gemm( + A.view(-1, A.shape[-1]), + self.B, + self.B_meta, + self.s, + self.workspace, + self.bits, + A.shape[0], + self.s.shape[1], + A.shape[1], + ) + + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C diff --git a/server/text_generation_server/layers/marlin/util.py b/server/text_generation_server/layers/marlin/util.py new file mode 100644 index 00000000..250d1714 --- /dev/null +++ b/server/text_generation_server/layers/marlin/util.py @@ -0,0 +1,141 @@ +import functools +from typing import List, Tuple + +import numpy +import torch +from text_generation_server.utils.import_utils import SYSTEM + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +def _check_marlin_kernels(): + if not (SYSTEM == "cuda" and has_sm_8_0): + raise NotImplementedError( + "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." + ) + + if marlin_kernels is None: + raise NotImplementedError( + "marlin is not installed, install it with: pip install server/marlin" + ) + + +# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 +@functools.cache +def get_perms() -> Tuple[List[int], List[int]]: + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def permute_scales(scales: torch.Tensor): + scale_perm, scale_perm_single = get_perms() + out_features = scales.shape[1] + if scales.shape[0] == 1: + scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + else: + scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] + return scales.reshape((-1, out_features)).contiguous() + + +# Functions below are from vLLM + + +def get_pack_factor(bits: int) -> int: + if 32 % bits != 0: + raise ValueError(f"Cannot {bits} bit values into uint32") + return 32 // bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + scale_perm, _ = get_perms() + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp From 7ebee37641a66f998fedb63b40806489cea4a85f Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 24 Jul 2024 15:32:14 -0400 Subject: [PATCH 164/340] fix: refactor adapter weight loading and mapping (#2193) * fix: refactor adapter weight loading and mapping * feat: enable lora load from directory * fix: adjust launcher for local lora adapters * feat: improve weight loading and add tests * fix: improve logging and rebase syntax issue * fix: impove adapter merge comments and remove unused conditional * fix: improve get_model_with_lora_adapters naming * fix: comment typo --- launcher/src/main.rs | 4 + server/tests/utils/test_adapter.py | 187 ++++++++++++++++++ .../text_generation_server/adapters/config.py | 13 +- .../text_generation_server/adapters/lora.py | 60 +++--- .../adapters/weights.py | 16 +- server/text_generation_server/cli.py | 34 ++-- server/text_generation_server/layers/lora.py | 5 +- .../text_generation_server/models/__init__.py | 127 +++++++++++- .../models/flash_causal_lm.py | 70 ------- .../models/flash_mistral.py | 85 -------- server/text_generation_server/models/model.py | 148 +------------- server/text_generation_server/server.py | 40 +--- .../text_generation_server/utils/adapter.py | 158 ++++++++++++--- 13 files changed, 498 insertions(+), 449 deletions(-) create mode 100644 server/tests/utils/test_adapter.py delete mode 100644 server/text_generation_server/models/flash_mistral.py diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 733cb1c1..3d3b068e 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1621,6 +1621,10 @@ fn main() -> Result<(), LauncherError> { // Download and convert lora adapters if any if let Some(lora_adapters) = &args.lora_adapters { for adapter in lora_adapters.split(',') { + // skip download if a path is provided + if adapter.contains('=') { + continue; + } download_convert_model( adapter, None, diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py new file mode 100644 index 00000000..cc1b076d --- /dev/null +++ b/server/tests/utils/test_adapter.py @@ -0,0 +1,187 @@ +import pytest +from unittest.mock import Mock +from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights + + +def test_get_attn_weights(): + # create a mock layer + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + # call the function + result = get_attn_weights(2, mock_layer) + + # assert the result + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_with_gate_up_proj(): + # create a mock layer with gate_up_proj + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # call the function + result = get_mlp_weights(3, mock_layer) + + # assert the result + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_mlp_weights_without_gate_up_proj(): + # create a mock layer without gate_up_proj + mock_layer = Mock() + mock_layer.mlp = Mock(spec=[]) + + # call the function + result = get_mlp_weights(1, mock_layer) + + # assert the result + assert result == {} + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_attn_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(layer_index, mock_layer) + + for k in ["q", "k", "v"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.self_attn.{k}_proj" + ) + + assert (layer_index, "o_proj") in result + assert ( + result[(layer_index, "o_proj")][0] + == f"model.layers.{layer_index}.self_attn.o_proj" + ) + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_mlp_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(layer_index, mock_layer) + + for k in ["gate", "up", "down"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.mlp.{k}_proj" + ) + + +def test_get_attn_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_attn_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_proj = Mock() + mock_layer.mlp.up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist. + # This is necessary because the use of `Mock` automatically creates any + # attributes that are accessed, even if they don't exist in the actual + # implementation. If `gate_up_proj` were created, `get_mlp_weights` might + # follow the wrong execution path and return an incorrect result. + del mock_layer.mlp.gate_up_proj + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py index 5261d4b5..2ee53b12 100644 --- a/server/text_generation_server/adapters/config.py +++ b/server/text_generation_server/adapters/config.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Set, Tuple import torch @@ -31,14 +31,3 @@ class AdapterConfig(ABC): weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: pass - - @abstractmethod - def load_batched_adapter_weights( - self, - model: "Model", - module_map: ModuleMap, - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 87543be2..ac143bb7 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig): adapter_weight_names.add(lora_b_name) return module_map, adapter_weight_names - def load_batched_adapter_weights( - self, - model: "Model", - module_map: Dict[str, Dict], - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - return LoraWeights.load( - self, - model, - module_map, - layer_type, - unused_weight_names, - ) - @classmethod def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) @@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights): def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: return [BatchLoraWeights] + # prepare pre-loaded lora weights for use in the model. + # + # this method processes and organizes lora weights for a specific layer type across all layers: + # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor. + # - retrieves weights from `module_map` based on the `layer_type`. + # - processes `nlayers` number of layers. + # - converts weights to the specified `dtype`. + # - shards weights across `world_size` number of processes using the `process_group`. + # - maps weights to specific layers using `target_to_layer`. + # - tracks `unused_weight_names` to identify any unused weights. + # + # the method handles weight transposition, scaling, and padding to ensure compatibility + # with SGMV or BGMV operations. @classmethod - def load( + def prepare_weights( cls, config: LoraConfig, - model: "Model", module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], + nlayers: int, + dtype: torch.dtype, + world_size: int, + process_group: ProcessGroup, + target_to_layer: Dict[str, Tuple[str, torch.Tensor]], ) -> Optional[AdapterWeights]: - nlayers = model.get_num_layers_for_type(layer_type) lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers for layer_id in range(nlayers): key = (layer_id, layer_type) - weight_name, layer = model.target_to_layer[key] + weight_name, layer = target_to_layer[key] base_weight = layer.base_layer.linear.weight base_device = base_weight.device @@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights): return None lora_a, lora_a_name = module_map[weight_name]["lora_A"] - lora_a = lora_a.to(base_device, model.dtype) + lora_a = lora_a.to(base_device, dtype) lora_b, lora_b_name = module_map[weight_name]["lora_B"] - lora_b = lora_b.to(base_device, model.dtype) + lora_b = lora_b.to(base_device, dtype) scale = get_scaling_factor( config.lora_alpha, @@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [ - pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list - ] - lora_b_list = [ - pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list - ] + lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] + lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] if lora_a_list: # update rank if it was padded @@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights): *shard_lora_weights( weights_a=lora_a_list, weights_b=lora_b_list, - split_dim=0 if model.is_row_parallel(layer_type) else 1, - process_group=model.process_group, + split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1, + process_group=process_group, ), config, ) @@ -293,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights): for rank_data in self.rank_data.values() ) - @classmethod - def key(cls) -> str: - return "lora" - @classmethod def load( self, diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py index 8f658756..da75dbcd 100644 --- a/server/text_generation_server/adapters/weights.py +++ b/server/text_generation_server/adapters/weights.py @@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC): def has_adapter(self, adapter_index: int) -> bool: pass - @abstractclassmethod - def key(cls) -> str: - pass - @abstractclassmethod def load( cls, @@ -71,13 +67,6 @@ class LayerAdapterWeights: return del self.adapter_weights[adapter_idx] - @property - def max_speculative_tokens(self) -> int: - return max( - adapter_weights.speculative_tokens - for adapter_weights in self.adapter_weights.values() - ) - def is_empty(self) -> bool: return len(self.adapter_weights) == 0 @@ -101,7 +90,7 @@ class LayerAdapterWeights: adapter_weights, meta, prefill, prefill_head_indices ) if batched_weights is not None: - batch_data[batch_type.key()] = batched_weights + batch_data = batched_weights return batch_data @@ -133,8 +122,7 @@ class AdapterBatchData: def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation ranks = set() - for layer_data in self.data.values(): - lora_data = layer_data.get("lora") + for lora_data in self.data.values(): if lora_data is None: continue diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 8ec2a5ae..8b5520bf 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -4,9 +4,10 @@ import typer from pathlib import Path from loguru import logger -from typing import Optional +from typing import Optional, List, Dict from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.adapter import parse_lora_adapters from text_generation_server.utils.log import log_master @@ -80,28 +81,19 @@ def serve( if otlp_endpoint is not None: setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) - lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) - - # split on comma and strip whitespace - lora_adapter_ids = ( - [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else [] - ) - - if len(lora_adapter_ids) > 0: - log_master( - logger.warning, - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", - ) + lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled # and warn the user - if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: - log_master( - logger.warning, - f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.", - ) - global CUDA_GRAPHS - CUDA_GRAPHS = None + if lora_adapters: + logger.warning("LoRA adapters enabled (experimental feature).") + + if "CUDA_GRAPHS" in os.environ: + logger.warning( + "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs." + ) + global CUDA_GRAPHS + CUDA_GRAPHS = None # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value @@ -117,7 +109,7 @@ def serve( ) server.serve( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index 0bb6db41..df5e92da 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -43,10 +43,7 @@ class LoraLinear(nn.Module): ) -> torch.Tensor: if adapter_data is None: return result - data = adapter_data.data.get(layer_type) - data: Optional["BatchLoraWeights"] = ( - data.get("lora") if data is not None else None - ) + data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) if has_sgmv() and data is not None and data.can_vectorize(self.process_group): # In tensor-parallel configurations, each GPU processes a specific segment of the output. diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1cd13a2a..ac7a8f3e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -6,7 +6,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi -from typing import Optional, List +from typing import Optional, List, Dict from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate @@ -33,6 +33,16 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( T5ForConditionalGeneration, ) + +from text_generation_server.utils.adapter import ( + AdapterParameters, + build_layer_weight_lookup, + load_and_merge_adapters, + AdapterInfo, +) +from text_generation_server.adapters.lora import LoraWeights + + from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_master @@ -50,7 +60,7 @@ __all__ = [ "Model", "CausalLM", "Seq2SeqLM", - "get_model", + "get_model_with_lora_adapters", ] FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." @@ -1115,3 +1125,116 @@ def get_model( ) raise ValueError(f"Unsupported model type {model_type}") + + +# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters +# this provides a post model loading hook to load adapters into the model after the model has been loaded +def get_model_with_lora_adapters( + model_id: str, + lora_adapters: Optional[List[AdapterInfo]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[str], + trust_remote_code: bool, + max_input_tokens: int, + adapter_to_index: Dict[str, int], +): + lora_adapter_ids = [adapter.id for adapter in lora_adapters] + model = get_model( + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + max_input_tokens, + ) + + if len(lora_adapters) > 0: + target_to_layer = build_layer_weight_lookup(model.model) + + for index, adapter in enumerate(lora_adapters): + # The AdapterParameters object allows for merging multiple adapters into a single adapter. + # At the moment, we only support loading a single adapter into the model, but we keep the + # AdapterParameters object for easier extension in the future. + adapter_parameters = AdapterParameters( + adapter_info=[adapter], + # when merging multiple adapters we can weight them differently + # if this is not set, all adapters will be weighted equally + # see: text_generation_server.utils.merges.strategies for impl + weights=None, + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + + adapter_index = index + 1 + adapter_to_index[adapter.id] = adapter_index + + logger.info( + f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}" + ) + weight_names = tuple([v[0] for v in target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + model.model_id, + adapter_parameters, + adapter_index, + weight_names, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + + adapter_layers = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + + for layer_name in adapter_layers: + nlayers = ( + 1 if layer_name == "lm_head" else len(model.model.model.layers) + ) + adapter_weights = LoraWeights.prepare_weights( + config=adapter_config, + module_map=module_map, + layer_type=layer_name, + unused_weight_names=unused_weight_names, + nlayers=nlayers, + dtype=model.dtype, + world_size=model.world_size, + process_group=model.process_group, + target_to_layer=target_to_layer, + ) + + if adapter_weights is None: + continue + + model.layer_to_adapter_weights[layer_name].add_adapter( + adapter_index, adapter_weights + ) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + model.loaded_adapters.add(adapter_index) + + return model diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5db62431..2ca40959 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,7 +1,6 @@ import math import os import time -import itertools import torch import torch.distributed @@ -1700,72 +1699,3 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py deleted file mode 100644 index 2b2bd2e0..00000000 --- a/server/text_generation_server/models/flash_mistral.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM - - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashMistral(FlashCausalLM): - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. - if hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index e7748bb9..159139de 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -4,20 +4,12 @@ import torch from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict from collections import defaultdict -from transformers import PreTrainedTokenizerBase, PretrainedConfig +from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights -from text_generation_server.utils.adapter import ( - load_and_merge_adapters, - AdapterParameters, - AdapterSource, -) -from text_generation_server.utils.log import log_master -from loguru import logger - BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -61,7 +53,6 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -142,140 +133,3 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) - - @property - def supports_adapter_loading(self) -> bool: - return False - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - return {} - - @property - def adapter_layers(self) -> List[str]: - return [] - - @property - def default_traced_adapter_layers(self) -> List[str]: - return [] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 0 - - def is_row_parallel(self, layer_type: str) -> bool: - return False - - @property - def max_speculative_tokens(self) -> int: - return max( - [ - weights.max_speculative_tokens - for weights in self.layer_to_adapter_weights.values() - ], - default=0, - ) - - def load_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - api_token: str, - dynamic: bool = True, - ): - """Loads adapter weights from disk / host memory on the GPU. - - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded - into model. Otherwise, the adapter weights are applied during the forward - pass and stored separately from the base model parameters. - """ - if self.target_to_layer is None: - self.target_to_layer = self.adapter_target_to_layer() - if adapter_index in self.loaded_adapters: - # Adapter already loaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if dynamic and not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." - ) - - log_master( - logger.info, - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", - ) - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - ( - module_map, - adapter_config, - adapter_weight_names, - adapter_tokenizer, - ) = load_and_merge_adapters( - self.model_id, - adapter_parameters, - adapter_source, - adapter_index, - weight_names, - api_token, - False, - ) - - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - adapter_weights = adapter_config.load_batched_adapter_weights( - self, - module_map, - layer_name, - unused_weight_names, - dynamic, - ) - - if adapter_weights is None: - continue - - layer_weights = self.layer_to_adapter_weights[layer_name] - layer_weights.add_adapter(adapter_index, adapter_weights) - - if len(unused_weight_names) > 0: - log_master( - logger.warning, - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", - ) - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.loaded_adapters.add(adapter_index) - - def offload_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - ): - """Offloads the adapter weights from GPU to CPU or disk.""" - if adapter_index not in self.loaded_adapters: - # Adapter already offloaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." - ) - - for layer_name in self.adapter_layers: - if layer_name in self.layer_to_adapter_weights: - self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) - - self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index aee287c6..7ac54603 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -9,11 +9,12 @@ from loguru import logger from grpc_reflection.v1alpha import reflection from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Dict from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models import Model, get_model +from text_generation_server.models import Model, get_model_with_lora_adapters +from text_generation_server.utils.adapter import AdapterInfo try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -30,9 +31,6 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.globals import set_model_id, set_adapter_to_index -from text_generation_server.utils.adapter import ( - AdapterParameters, -) class SignalHandler: @@ -195,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_id: str, - lora_adapter_ids: Optional[List[str]], + lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -207,7 +205,7 @@ def serve( ): async def serve_inner( model_id: str, - lora_adapter_ids: Optional[List[str]], + lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, @@ -228,9 +226,9 @@ def serve( server_urls = [local_url] try: - model = get_model( + model = get_model_with_lora_adapters( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, @@ -238,29 +236,9 @@ def serve( dtype, trust_remote_code, max_input_tokens, + adapter_to_index, ) - if len(lora_adapter_ids) > 0: - for index, adapter_id in enumerate(lora_adapter_ids): - # TODO: improve non merged adapter loading and long term - # improve adapter loading as a whole - adapter_parameters = AdapterParameters( - adapter_ids=[adapter_id], - weights=None, # will be set to 1 - merge_strategy=0, - density=1.0, - majority_sign_method=0, - ) - adapter_index = index + 1 - adapter_to_index[adapter_id] = adapter_index - model.load_adapter( - adapter_parameters, - None, # adapter_source - adapter_index, - None, # api_token - False, # dynamic - ) - except Exception: logger.exception("Error when initializing model") raise @@ -297,7 +275,7 @@ def serve( asyncio.run( serve_inner( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 4e2492de..1db5f77b 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -5,12 +5,11 @@ import warnings from dataclasses import dataclass from functools import lru_cache -from typing import TYPE_CHECKING, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple, Optional, List from safetensors.torch import load_file from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer -from text_generation_server.pb import generate_pb2 from text_generation_server.utils.merges.strategies import merge_adapters from text_generation_server.utils import hub @@ -24,9 +23,15 @@ if TYPE_CHECKING: BASE_MODEL_ADAPTER_ID = "__base_model__" +@dataclass +class AdapterInfo: + id: str + path: Optional[str] + + @dataclass class AdapterParameters: - adapter_ids: Tuple[str] + adapter_info: Tuple[AdapterInfo] weights: Tuple[float] merge_strategy: NotImplemented density: float @@ -40,37 +45,47 @@ class AdapterSource: revision: str +def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]: + if not lora_adapters: + return [] + + adapter_list = [] + for adapter in lora_adapters.split(","): + parts = adapter.strip().split("=") + if len(parts) == 1: + adapter_list.append(AdapterInfo(id=parts[0], path=None)) + elif len(parts) == 2: + adapter_list.append(AdapterInfo(id=parts[0], path=parts[1])) + else: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + return adapter_list + + def load_and_merge_adapters( model_id: str, adapter_parameters: AdapterParameters, - adapter_source: str, adapter_index: int, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - if len(adapter_parameters.adapter_ids) == 1: + + if len(adapter_parameters.adapter_info) == 1: + adapter_info = next(iter(adapter_parameters.adapter_info)) return load_module_map( model_id, - adapter_parameters.adapter_ids[0], - adapter_source, + adapter_info.id, + adapter_info.path, weight_names, - api_token, trust_remote_code, ) - adapter_params = AdapterParametersContainer( - adapter_parameters, adapter_source, adapter_index - ) - return _load_and_merge( - model_id, adapter_params, weight_names, api_token, trust_remote_code - ) + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) + return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code) @dataclass class AdapterParametersContainer: adapter_parameters: AdapterParameters - adapter_source: str adapter_index: int def __hash__(self) -> int: @@ -82,7 +97,6 @@ def _load_and_merge( model_id: str, adapter_params: AdapterParametersContainer, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: params = adapter_params.adapter_parameters @@ -90,17 +104,16 @@ def _load_and_merge( adapters_to_merge = [] merged_weight_names = set() tokenizer = None - for adapter_id in params.adapter_ids: - if adapter_id == BASE_MODEL_ADAPTER_ID: + for adapter in params.adapter_info: + if adapter.id == BASE_MODEL_ADAPTER_ID: raise ValueError("Base model adapter cannot be merged.") module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( load_module_map( model_id, - adapter_id, - adapter_params.adapter_source, + adapter.id, + adapter.path, weight_names, - api_token, trust_remote_code, ) ) @@ -159,25 +172,28 @@ def check_architectures( def load_module_map( model_id: str, adapter_id: str, - adapter_source: str, + adapter_path: Optional[str], weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: revision = "main" - adapter_config = LoraConfig.load(adapter_id, api_token) - if adapter_config.base_model_name_or_path != model_id: + adapter_config = LoraConfig.load(adapter_path or adapter_id, None) + + if not adapter_path and adapter_config.base_model_name_or_path != model_id: check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) - adapter_filenames = hub._cached_adapter_weight_files( - adapter_id, revision=revision, extension=".safetensors" + adapter_filenames = ( + hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors") + if adapter_path + else hub._cached_adapter_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) ) try: adapter_tokenizer = AutoTokenizer.from_pretrained( adapter_config.config_path, - token=api_token, trust_remote_code=trust_remote_code, ) except Exception: @@ -194,3 +210,87 @@ def load_module_map( adapter_weights, weight_names ) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer + + +def get_attn_weights(i, layer): + qkv = layer.self_attn.query_key_value + weights = {} + + for k in ["q", "k", "v"]: + key = (i, f"{k}_proj") + value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) + weights[key] = value + + weights[(i, "o_proj")] = ( + f"model.layers.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + return weights + + +def get_mlp_weights(i, layer): + weights = {} + if hasattr(layer, "mlp"): + mlp = layer.mlp + if hasattr(mlp, "gate_up_proj"): + # handle combined gate_up_proj (e.g., for some LLaMA variants) + weights.update( + { + (i, "gate_proj"): ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_up_proj, + ), + (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj), + } + ) + else: + # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma) + if hasattr(mlp, "gate_proj"): + weights[(i, "gate_proj")] = ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_proj, + ) + if hasattr(mlp, "up_proj"): + weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + + if hasattr(mlp, "down_proj"): + weights[(i, "down_proj")] = ( + f"model.layers.{i}.mlp.down_proj", + mlp.down_proj, + ) + + return weights + + +# build_layer_weight_lookup creates a mapping of model layers to their corresponding +# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples +# containing the weight tensor path and the actual layer object. This mapping is needed +# for the lora adapter to know which weights to update when applying the adapter. +def build_layer_weight_lookup(model): + if hasattr(model, "language_model"): + m = model.language_model.model + elif hasattr(model, "text_model"): + m = model.text_model.model + else: + m = model.model + + layer_weights = {} + + for i, layer in enumerate(m.layers): + attn_weights = get_attn_weights(i, layer) + mlp_weights = get_mlp_weights(i, layer) + + layer_weights.update(attn_weights) + layer_weights.update(mlp_weights) + + lm_head = None + if hasattr(m, "lm_head"): + lm_head = m.lm_head + elif hasattr(model, "lm_head"): + lm_head = model.lm_head + + if lm_head: + layer_weights[(0, "lm_head")] = ("lm_head", lm_head) + + return layer_weights From 69db13e5e52132fbf76d0f6ccfffa76aed992600 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 25 Jul 2024 11:21:17 +0200 Subject: [PATCH 165/340] Using g6 instead of g5. (#2281) * Using g6 instead of g5. * Update the idefics2 snapshot. --- .../test_idefics2/test_flash_idefics2_next_all_params.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json index 45601505..dab437b9 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json @@ -30,7 +30,7 @@ }, { "id": 264, - "logprob": -0.37573242, + "logprob": -0.38061523, "special": false, "text": " a" }, @@ -42,7 +42,7 @@ }, { "id": 4480, - "logprob": -0.3322754, + "logprob": -0.26782227, "special": false, "text": " feature" }, @@ -78,7 +78,7 @@ }, { "id": 13, - "logprob": 0.0, + "logprob": -0.10632324, "special": false, "text": "\n" } From 64ffd642fa3456ab809ecaaf8ce9b67c673817da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 25 Jul 2024 13:34:44 +0200 Subject: [PATCH 166/340] Some small fixes for the Torch 2.4.0 update (#2304) * Fix GPTQ autotune data type to be compatible with Torch 2.4.0 * Update poetry lock file * Fix small PaliGemma logprob differences after the torch update --- .../test_flash_pali_gemma_two_images.json | 16 +-- server/poetry.lock | 123 ++++++------------ .../layers/gptq/custom_autotune.py | 2 +- 3 files changed, 50 insertions(+), 91 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json index ab4f3015..e3b5575c 100644 --- a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json @@ -8,49 +8,49 @@ "tokens": [ { "id": 2502, - "logprob": -1.734375, + "logprob": -1.7890625, "special": false, "text": "image" }, { "id": 2196, - "logprob": -0.5756836, + "logprob": -0.53125, "special": false, "text": " result" }, { "id": 604, - "logprob": -0.007843018, + "logprob": -0.0077209473, "special": false, "text": " for" }, { "id": 12254, - "logprob": -1.7167969, + "logprob": -1.703125, "special": false, "text": " chicken" }, { "id": 611, - "logprob": -0.17053223, + "logprob": -0.21582031, "special": false, "text": " on" }, { "id": 573, - "logprob": -0.7626953, + "logprob": -0.734375, "special": false, "text": " the" }, { "id": 8318, - "logprob": -0.02709961, + "logprob": -0.026000977, "special": false, "text": " beach" }, { "id": 1, - "logprob": -0.20739746, + "logprob": -0.2109375, "special": true, "text": "" } diff --git a/server/poetry.lock b/server/poetry.lock index 4dde0dab..c993a441 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -931,20 +931,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "intel-openmp" -version = "2021.4.0" -description = "Intel OpenMP* Runtime Library" -optional = true -python-versions = "*" -files = [ - {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, -] - [[package]] name = "interegular" version = "0.3.3" @@ -1153,24 +1139,6 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] -[[package]] -name = "mkl" -version = "2021.4.0" -description = "Intel® oneAPI Math Kernel Library" -optional = true -python-versions = "*" -files = [ - {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, - {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, - {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, -] - -[package.dependencies] -intel-openmp = "==2021.*" -tbb = "==2021.*" - [[package]] name = "mpmath" version = "1.3.0" @@ -1465,12 +1433,13 @@ files = [ [[package]] name = "nvidia-cudnn-cu12" -version = "8.9.2.26" +version = "9.1.0.70" description = "cuDNN runtime libraries" optional = true python-versions = ">=3" files = [ - {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, + {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, + {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, ] [package.dependencies] @@ -2841,19 +2810,6 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] -[[package]] -name = "tbb" -version = "2021.13.0" -description = "Intel® oneAPI Threading Building Blocks (oneTBB)" -optional = true -python-versions = "*" -files = [ - {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, - {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, - {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, - {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, -] - [[package]] name = "texttable" version = "1.7.0" @@ -2995,44 +2951,43 @@ files = [ [[package]] name = "torch" -version = "2.3.1" +version = "2.4.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = true python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, - {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, - {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, - {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, - {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, - {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, - {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, - {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, - {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, - {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, - {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, - {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, - {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, - {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, - {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, - {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, - {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, - {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, - {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, - {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, + {file = "torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:4ed94583e244af51d6a8d28701ca5a9e02d1219e782f5a01dd401f90af17d8ac"}, + {file = "torch-2.4.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c4ca297b7bd58b506bfd6e78ffd14eb97c0e7797dcd7965df62f50bb575d8954"}, + {file = "torch-2.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2497cbc7b3c951d69b276ca51fe01c2865db67040ac67f5fc20b03e41d16ea4a"}, + {file = "torch-2.4.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:685418ab93730efbee71528821ff54005596970dd497bf03c89204fb7e3f71de"}, + {file = "torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e743adadd8c8152bb8373543964551a7cb7cc20ba898dc8f9c0cdbe47c283de0"}, + {file = "torch-2.4.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:7334325c0292cbd5c2eac085f449bf57d3690932eac37027e193ba775703c9e6"}, + {file = "torch-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:97730014da4c57ffacb3c09298c6ce05400606e890bd7a05008d13dd086e46b1"}, + {file = "torch-2.4.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f169b4ea6dc93b3a33319611fcc47dc1406e4dd539844dcbd2dec4c1b96e166d"}, + {file = "torch-2.4.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:997084a0f9784d2a89095a6dc67c7925e21bf25dea0b3d069b41195016ccfcbb"}, + {file = "torch-2.4.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc3988e8b36d1e8b998d143255d9408d8c75da4ab6dd0dcfd23b623dfb0f0f57"}, + {file = "torch-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3374128bbf7e62cdaed6c237bfd39809fbcfaa576bee91e904706840c3f2195c"}, + {file = "torch-2.4.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:91aaf00bfe1ffa44dc5b52809d9a95129fca10212eca3ac26420eb11727c6288"}, + {file = "torch-2.4.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cc30457ea5489c62747d3306438af00c606b509d78822a88f804202ba63111ed"}, + {file = "torch-2.4.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a046491aaf96d1215e65e1fa85911ef2ded6d49ea34c8df4d0638879f2402eef"}, + {file = "torch-2.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:688eec9240f3ce775f22e1e1a5ab9894f3d5fe60f3f586deb7dbd23a46a83916"}, + {file = "torch-2.4.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:3af4de2a618fb065e78404c4ba27a818a7b7957eaeff28c6c66ce7fb504b68b8"}, + {file = "torch-2.4.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:618808d3f610d5f180e47a697d4ec90b810953bb1e020f424b2ac7fb0884b545"}, + {file = "torch-2.4.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed765d232d23566052ba83632ec73a4fccde00b4c94ad45d63b471b09d63b7a7"}, + {file = "torch-2.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2feb98ac470109472fb10dfef38622a7ee08482a16c357863ebc7bc7db7c8f7"}, + {file = "torch-2.4.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:8940fc8b97a4c61fdb5d46a368f21f4a3a562a17879e932eb51a5ec62310cb31"}, ] [package.dependencies] filelock = "*" fsspec = "*" jinja2 = "*" -mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} @@ -3040,12 +2995,12 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.9.1)"] +optree = ["optree (>=0.11.0)"] [[package]] name = "tqdm" @@ -3137,17 +3092,21 @@ vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.3.1" +version = "3.0.0" description = "A language and compiler for custom Deep Learning operations" optional = true python-versions = "*" files = [ - {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, - {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, - {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, - {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, - {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, - {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -3155,8 +3114,8 @@ filelock = "*" [package.extras] build = ["cmake (>=3.20)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] -tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "typer" @@ -3556,4 +3515,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "52d2dda4603443ee76ce880f3f1fdf7f044e7ba6df9973d230f52b0fbc0dbf24" +content-hash = "c94bbdf8131750891fb3f7132066718534129d85a4c09126d8d01c2de6c72798" diff --git a/server/text_generation_server/layers/gptq/custom_autotune.py b/server/text_generation_server/layers/gptq/custom_autotune.py index 1eb40f1e..0388ef20 100644 --- a/server/text_generation_server/layers/gptq/custom_autotune.py +++ b/server/text_generation_server/layers/gptq/custom_autotune.py @@ -91,7 +91,7 @@ class Autotuner(triton.KernelInterface): kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 ) except triton.OutOfResources: - return (float("inf"), float("inf"), float("inf")) + return [float("inf"), float("inf"), float("inf")] def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) From d5e054342e1fbfbd89f6973cf07aaefb8262c479 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 25 Jul 2024 14:44:21 +0200 Subject: [PATCH 167/340] Fixing idefics on g6 tests. (#2306) --- .../test_flash_idefics2_next_all_params.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json index dab437b9..1fad0b96 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json @@ -30,19 +30,19 @@ }, { "id": 264, - "logprob": -0.38061523, + "logprob": -0.37573242, "special": false, "text": " a" }, { "id": 633, - "logprob": -0.09301758, + "logprob": -0.09161377, "special": false, "text": " new" }, { "id": 4480, - "logprob": -0.26782227, + "logprob": -0.26171875, "special": false, "text": " feature" }, @@ -78,7 +78,7 @@ }, { "id": 13, - "logprob": -0.10632324, + "logprob": 0.0, "special": false, "text": "\n" } From 1674f441d046160a7f4c210afb498c2c4e43ae60 Mon Sep 17 00:00:00 2001 From: Adrien Date: Thu, 25 Jul 2024 16:06:00 +0200 Subject: [PATCH 168/340] Fix registry name (#2307) --- server/fbgemm_remove_unused.patch | 306 ------------------------------ server/fix_torch90a.sh | 11 -- 2 files changed, 317 deletions(-) delete mode 100644 server/fbgemm_remove_unused.patch delete mode 100755 server/fix_torch90a.sh diff --git a/server/fbgemm_remove_unused.patch b/server/fbgemm_remove_unused.patch deleted file mode 100644 index ad6af811..00000000 --- a/server/fbgemm_remove_unused.patch +++ /dev/null @@ -1,306 +0,0 @@ -diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt -index 2244ea6f..96265a48 100644 ---- a/fbgemm_gpu/CMakeLists.txt -+++ b/fbgemm_gpu/CMakeLists.txt -@@ -94,14 +94,14 @@ endif() - # Build Experimental Modules - ################################################################################ - --if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) -- # TODO: Figure out NCCL/RCCL integration with ROCm -- add_subdirectory(experimental/example) --endif() -- --if(NOT FBGEMM_CPU_ONLY) -- add_subdirectory(experimental/gemm) --endif() -+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) -+# # TODO: Figure out NCCL/RCCL integration with ROCm -+# add_subdirectory(experimental/example) -+# endif() -+ -+# if(NOT FBGEMM_CPU_ONLY) -+# add_subdirectory(experimental/gemm) -+# endif() - - if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) - # CUTLASS currently doesn't build on ROCm and CK hasnt yet been added: -diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake -index c56773fe..0c0d349e 100644 ---- a/fbgemm_gpu/FbgemmGpu.cmake -+++ b/fbgemm_gpu/FbgemmGpu.cmake -@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources} - ################################################################################ - - set(fbgemm_gpu_sources_static_cpu -- codegen/training/forward/embedding_forward_split_cpu.cpp -- codegen/inference/embedding_forward_quantized_host_cpu.cpp -- codegen/training/backward/embedding_backward_dense_host_cpu.cpp -- codegen/utils/embedding_bounds_check_host_cpu.cpp -- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp -- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp -- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp -- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp -- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp -- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp -- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp -- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp -- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp -- src/input_combine_ops/input_combine_cpu.cpp -- src/layout_transform_ops/layout_transform_ops_cpu.cpp -+ # codegen/training/forward/embedding_forward_split_cpu.cpp -+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp -+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp -+ # codegen/utils/embedding_bounds_check_host_cpu.cpp -+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp -+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp -+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp -+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp -+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp -+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp -+ # src/input_combine_ops/input_combine_cpu.cpp -+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp - src/quantize_ops/quantize_ops_cpu.cpp - src/quantize_ops/quantize_ops_meta.cpp -- src/sparse_ops/sparse_ops_cpu.cpp -- src/sparse_ops/sparse_ops_meta.cpp -- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp -- src/split_embeddings_cache/linearize_cache_indices.cpp -- src/split_embeddings_cache/lfu_cache_populate_byte.cpp -- src/split_embeddings_cache/lru_cache_populate_byte.cpp -- src/split_embeddings_cache/lxu_cache.cpp -- src/split_embeddings_cache/split_embeddings_cache_ops.cpp -- codegen/training/index_select/batch_index_select_dim0_ops.cpp -- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) -+ # src/sparse_ops/sparse_ops_cpu.cpp -+ # src/sparse_ops/sparse_ops_meta.cpp -+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp -+ # src/split_embeddings_cache/linearize_cache_indices.cpp -+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp -+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp -+ # src/split_embeddings_cache/lxu_cache.cpp -+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp -+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp -+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) -+) - - if(NOT FBGEMM_CPU_ONLY) - list(APPEND fbgemm_gpu_sources_static_cpu -- codegen/inference/embedding_forward_quantized_host.cpp -- codegen/utils/embedding_bounds_check_host.cpp -- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp -- src/layout_transform_ops/layout_transform_ops_gpu.cpp -- src/memory_utils/memory_utils.cpp -- src/memory_utils/memory_utils_ops.cpp -- src/memory_utils/memory_utils_ops_cpu.cpp -- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp -- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp -+ # codegen/inference/embedding_forward_quantized_host.cpp -+ # codegen/utils/embedding_bounds_check_host.cpp -+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp -+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp -+ # src/memory_utils/memory_utils.cpp -+ # src/memory_utils/memory_utils_ops.cpp -+ # src/memory_utils/memory_utils_ops_cpu.cpp -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp - src/quantize_ops/quantize_ops_gpu.cpp -- src/sparse_ops/sparse_ops_gpu.cpp -- src/split_embeddings_utils/split_embeddings_utils.cpp -- src/split_embeddings_cache/split_embeddings_cache_ops.cu -- src/metric_ops/metric_ops_host.cpp -- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp -- src/input_combine_ops/input_combine_gpu.cpp -- codegen/training/index_select/batch_index_select_dim0_host.cpp) -+ # src/sparse_ops/sparse_ops_gpu.cpp -+ # src/split_embeddings_utils/split_embeddings_utils.cpp -+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu -+ # src/metric_ops/metric_ops_host.cpp -+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp -+ # src/input_combine_ops/input_combine_gpu.cpp -+ # codegen/training/index_select/batch_index_select_dim0_host.cpp) -+ ) - - if(NVML_LIB_PATH OR USE_ROCM) - message(STATUS "Adding merge_pooled_embeddings sources") -@@ -516,36 +518,36 @@ endif() - - if(NOT FBGEMM_CPU_ONLY) - set(fbgemm_gpu_sources_static_gpu -- codegen/utils/embedding_bounds_check.cu -- codegen/inference/embedding_forward_quantized_split_lookup.cu -- src/embedding_inplace_ops/embedding_inplace_update.cu -- src/histogram_binning_calibration_ops.cu -- src/input_combine_ops/input_combine.cu -- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu -- src/memory_utils/memory_utils.cu -- src/memory_utils/memory_utils_ops.cu -- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu -- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu -- src/jagged_tensor_ops/dense_to_jagged_forward.cu -- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu -- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu -- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu -- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu -- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu -- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu -- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu -- src/jagged_tensor_ops/jagged_softmax_backward.cu -- src/jagged_tensor_ops/jagged_softmax_forward.cu -- src/jagged_tensor_ops/jagged_tensor_ops.cu -- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu -- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu -- src/jagged_tensor_ops/jagged_unique_indices.cu -- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu -- src/layout_transform_ops/layout_transform_ops.cu -- src/metric_ops/metric_ops.cu -- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu -- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu -- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu -+ # codegen/utils/embedding_bounds_check.cu -+ # codegen/inference/embedding_forward_quantized_split_lookup.cu -+ # src/embedding_inplace_ops/embedding_inplace_update.cu -+ # src/histogram_binning_calibration_ops.cu -+ # src/input_combine_ops/input_combine.cu -+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu -+ # src/memory_utils/memory_utils.cu -+ # src/memory_utils/memory_utils_ops.cu -+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu -+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu -+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu -+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu -+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu -+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu -+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu -+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu -+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu -+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu -+ # src/jagged_tensor_ops/jagged_softmax_backward.cu -+ # src/jagged_tensor_ops/jagged_softmax_forward.cu -+ # src/jagged_tensor_ops/jagged_tensor_ops.cu -+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu -+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu -+ # src/jagged_tensor_ops/jagged_unique_indices.cu -+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu -+ # src/layout_transform_ops/layout_transform_ops.cu -+ # src/metric_ops/metric_ops.cu -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu -+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu -+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu - src/quantize_ops/quantize_bfloat16.cu - src/quantize_ops/quantize_fp8_rowwise.cu - src/quantize_ops/quantize_fused_8bit_rowwise.cu -@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY) - src/quantize_ops/quantize_msfp.cu - src/quantize_ops/quantize_padded_fp8_rowwise.cu - src/quantize_ops/quantize_mx.cu -- src/sparse_ops/sparse_async_cumsum.cu -- src/sparse_ops/sparse_block_bucketize_features.cu -- src/sparse_ops/sparse_bucketize_features.cu -- src/sparse_ops/sparse_batched_unary_embeddings.cu -- src/sparse_ops/sparse_compute_frequency_sequence.cu -- src/sparse_ops/sparse_expand_into_jagged_permute.cu -- src/sparse_ops/sparse_group_index.cu -- src/sparse_ops/sparse_index_add.cu -- src/sparse_ops/sparse_index_select.cu -- src/sparse_ops/sparse_invert_permute.cu -- src/sparse_ops/sparse_pack_segments_backward.cu -- src/sparse_ops/sparse_pack_segments_forward.cu -- src/sparse_ops/sparse_permute_1d.cu -- src/sparse_ops/sparse_permute_2d.cu -- src/sparse_ops/sparse_permute102.cu -- src/sparse_ops/sparse_permute_embeddings.cu -- src/sparse_ops/sparse_range.cu -- src/sparse_ops/sparse_reorder_batched_ad.cu -- src/sparse_ops/sparse_segment_sum_csr.cu -- src/sparse_ops/sparse_zipf.cu -- src/split_embeddings_cache/lfu_cache_find.cu -- src/split_embeddings_cache/lfu_cache_populate.cu -- src/split_embeddings_cache/lfu_cache_populate_byte.cu -- src/split_embeddings_cache/lru_cache_find.cu -- src/split_embeddings_cache/lru_cache_populate.cu -- src/split_embeddings_cache/lru_cache_populate_byte.cu -- src/split_embeddings_cache/lxu_cache.cu -- src/split_embeddings_cache/linearize_cache_indices.cu -- src/split_embeddings_cache/reset_weight_momentum.cu -- src/split_embeddings_utils/generate_vbe_metadata.cu -- src/split_embeddings_utils/get_infos_metadata.cu -- src/split_embeddings_utils/radix_sort_pairs.cu -- src/split_embeddings_utils/transpose_embedding_input.cu) -+ # src/sparse_ops/sparse_async_cumsum.cu -+ # src/sparse_ops/sparse_block_bucketize_features.cu -+ # src/sparse_ops/sparse_bucketize_features.cu -+ # src/sparse_ops/sparse_batched_unary_embeddings.cu -+ # src/sparse_ops/sparse_compute_frequency_sequence.cu -+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu -+ # src/sparse_ops/sparse_group_index.cu -+ # src/sparse_ops/sparse_index_add.cu -+ # src/sparse_ops/sparse_index_select.cu -+ # src/sparse_ops/sparse_invert_permute.cu -+ # src/sparse_ops/sparse_pack_segments_backward.cu -+ # src/sparse_ops/sparse_pack_segments_forward.cu -+ # src/sparse_ops/sparse_permute_1d.cu -+ # src/sparse_ops/sparse_permute_2d.cu -+ # src/sparse_ops/sparse_permute102.cu -+ # src/sparse_ops/sparse_permute_embeddings.cu -+ # src/sparse_ops/sparse_range.cu -+ # src/sparse_ops/sparse_reorder_batched_ad.cu -+ # src/sparse_ops/sparse_segment_sum_csr.cu -+ # src/sparse_ops/sparse_zipf.cu -+ # src/split_embeddings_cache/lfu_cache_find.cu -+ # src/split_embeddings_cache/lfu_cache_populate.cu -+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu -+ # src/split_embeddings_cache/lru_cache_find.cu -+ # src/split_embeddings_cache/lru_cache_populate.cu -+ # src/split_embeddings_cache/lru_cache_populate_byte.cu -+ # src/split_embeddings_cache/lxu_cache.cu -+ # src/split_embeddings_cache/linearize_cache_indices.cu -+ # src/split_embeddings_cache/reset_weight_momentum.cu -+ # src/split_embeddings_utils/generate_vbe_metadata.cu -+ # src/split_embeddings_utils/get_infos_metadata.cu -+ # src/split_embeddings_utils/radix_sort_pairs.cu -+ # src/split_embeddings_utils/transpose_embedding_input.cu) -+ ) - - set_source_files_properties(${fbgemm_gpu_sources_static_gpu} - PROPERTIES COMPILE_OPTIONS -diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt -index 01f1d6ab..a6b8d7a8 100644 ---- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt -+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt -@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories - ${THIRDPARTY}/json/include - ${NCCL_INCLUDE_DIRS}) - --set(attention_ops_sources -- src/attention/attention.cpp -- src/attention/gqa_attn_splitk.cu) -+# set(attention_ops_sources -+# src/attention/attention.cpp -+# src/attention/gqa_attn_splitk.cu) - - set(quantize_ops_sources - src/quantize/cutlass_extensions.cu - src/quantize/quantize.cu - src/quantize/quantize.cpp) - --set(comm_ops_sources -- src/comm/car.cu -- src/comm/car.cpp) -+# set(comm_ops_sources -+# src/comm/car.cu -+# src/comm/car.cpp) - - set(experimental_gen_ai_cpp_source_files -- ${attention_ops_sources} -+ # ${attention_ops_sources} - ${quantize_ops_sources} -- ${comm_ops_sources}) -+ # ${comm_ops_sources} -+) - - set_source_files_properties(${experimental_gen_ai_cpp_source_files} - PROPERTIES INCLUDE_DIRECTORIES diff --git a/server/fix_torch90a.sh b/server/fix_torch90a.sh deleted file mode 100755 index 5e444828..00000000 --- a/server/fix_torch90a.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -# This script is required to patch torch < 2.4 -# It adds the 90a cuda target (H100) -# This target is required to build FBGEMM kernels - -torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|') - -sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch -sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch -sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch From fc6d80fdb80e904428d8ad85f786d8cef0e3d82d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 26 Jul 2024 14:57:24 +0200 Subject: [PATCH 169/340] Support tied embeddings in 0.5B and 1.5B Qwen2 models (#2313) --- .../custom_modeling/flash_qwen2_modeling.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index a98709c5..f40d126b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -262,6 +262,9 @@ class Qwen2Layer(nn.Module): class Qwen2Model(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() + + prefix = f"{prefix}.model" if prefix else "model" + process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() @@ -335,15 +338,16 @@ class Qwen2ForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - if not prefix: - prefix = "model" - else: - prefix = f"{prefix}.model" - self.model = Qwen2Model(prefix, config, weights) + + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + self.lm_head = SpeculativeHead.load( config, - prefix="lm_head", + prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) self.max_past = config.sliding_window From a87791d7c949cc41e8f079cb3a6c047839aae818 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 26 Jul 2024 10:29:09 -0400 Subject: [PATCH 170/340] feat: add ruff and resolve issue (#2262) * feat: add ruff and resolve issue * fix: update client exports and adjust after rebase * fix: adjust syntax to avoid circular import * fix: adjust client ruff settings * fix: lint and refactor import check and avoid model enum as global names * fix: improve fbgemm_gpu check and lints * fix: update lints * fix: prefer comparing model enum over str * fix: adjust lints and ignore specific rules * fix: avoid unneeded quantize check --- .pre-commit-config.yaml | 5 ++ clients/python/text_generation/__init__.py | 14 +++- .../python/text_generation/inference_api.py | 2 +- integration-tests/conftest.py | 3 +- integration-tests/models/test_chat_llama.py | 3 - .../models/test_flash_pali_gemma.py | 2 - integration-tests/models/test_idefics.py | 4 +- integration-tests/models/test_tools_llama.py | 11 ++- load_tests/orca.py | 2 +- server/tests/utils/test_hub.py | 2 - server/tests/utils/test_layers.py | 1 - server/tests/utils/test_weights.py | 12 +--- .../text_generation_server/adapters/config.py | 5 +- .../text_generation_server/adapters/lora.py | 5 +- server/text_generation_server/cli.py | 13 ++-- .../text_generation_server/layers/__init__.py | 14 ++++ .../layers/attention/__init__.py | 9 +++ .../layers/attention/cuda.py | 1 - .../layers/attention/flash_attn_triton.py | 7 +- .../layers/attention/rocm.py | 1 - .../layers/awq/quantize/qmodule.py | 1 - server/text_generation_server/layers/bnb.py | 1 - server/text_generation_server/layers/fp8.py | 15 +++- .../layers/gptq/__init__.py | 69 ++++++++----------- .../layers/gptq/quant_linear.py | 10 +-- .../layers/gptq/quantize.py | 13 ++-- .../layers/gptq/utils.py | 56 +++++++++++++++ .../text_generation_server/layers/linear.py | 2 - server/text_generation_server/layers/lora.py | 6 +- .../layers/marlin/__init__.py | 3 - .../layers/marlin/marlin.py | 6 +- .../text_generation_server/layers/rotary.py | 3 - .../layers/speculative.py | 2 +- .../layers/tensor_parallel.py | 12 +--- .../text_generation_server/models/__init__.py | 6 +- .../models/causal_lm.py | 4 +- .../models/custom_modeling/bloom_modeling.py | 2 +- .../models/custom_modeling/clip.py | 20 ++---- .../custom_modeling/flash_cohere_modeling.py | 1 - .../custom_modeling/flash_dbrx_modeling.py | 1 - .../flash_deepseek_v2_modeling.py | 6 ++ .../custom_modeling/flash_llama_modeling.py | 3 +- .../custom_modeling/flash_mistral_modeling.py | 2 - .../custom_modeling/flash_mixtral_modeling.py | 2 - .../flash_pali_gemma_modeling.py | 1 - .../custom_modeling/flash_qwen2_modeling.py | 1 - .../models/custom_modeling/idefics2.py | 4 +- .../idefics_image_processing.py | 3 +- .../custom_modeling/idefics_modeling.py | 17 ++--- .../custom_modeling/idefics_perceiver.py | 1 - .../custom_modeling/idefics_processing.py | 3 - .../models/custom_modeling/idefics_vision.py | 2 - .../models/custom_modeling/llava_next.py | 2 +- .../models/custom_modeling/mpt_modeling.py | 25 +++---- .../models/custom_modeling/neox_modeling.py | 14 +--- .../models/custom_modeling/phi_modeling.py | 2 +- .../models/custom_modeling/siglip.py | 29 ++------ .../models/custom_modeling/t5_modeling.py | 15 +++- .../models/custom_modeling/vlm.py | 2 +- .../models/flash_causal_lm.py | 2 +- .../models/galactica.py | 8 --- .../text_generation_server/models/idefics.py | 7 +- .../models/idefics_causal_lm.py | 4 +- server/text_generation_server/models/mamba.py | 10 ++- server/text_generation_server/models/model.py | 2 +- .../models/pali_gemma.py | 7 +- .../models/seq2seq_lm.py | 7 +- server/text_generation_server/models/types.py | 1 - server/text_generation_server/server.py | 2 +- .../utils/import_utils.py | 10 ++- .../utils/merges/strategies.py | 21 +++--- server/text_generation_server/utils/peft.py | 2 +- .../utils/quantization.py | 1 - server/text_generation_server/utils/tokens.py | 1 - 74 files changed, 266 insertions(+), 302 deletions(-) create mode 100644 server/text_generation_server/layers/gptq/utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45bc07a5..bb3879b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,3 +16,8 @@ repos: - id: fmt - id: cargo-check - id: clippy +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index d7a09c9e..ca783dcd 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -19,5 +19,15 @@ DEPRECATION_WARNING = ( "Please use the `InferenceClient` from the `huggingface_hub` package instead." ) -from text_generation.client import Client, AsyncClient -from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient +from text_generation.client import Client, AsyncClient # noqa E402 +from text_generation.inference_api import ( # noqa E402 + InferenceAPIClient, + InferenceAPIAsyncClient, +) + +__all__ = [ + "Client", + "AsyncClient", + "InferenceAPIClient", + "InferenceAPIAsyncClient", +] diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py index 93b0de8d..b3b98ed2 100644 --- a/clients/python/text_generation/inference_api.py +++ b/clients/python/text_generation/inference_api.py @@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]: List[DeployedModel]: list of all currently deployed models """ resp = requests.get( - f"https://api-inference.huggingface.co/framework/text-generation-inference", + "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=5, ) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 60146ad1..46a8769f 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -4,7 +4,6 @@ import json import math import os import random -import re import shutil import subprocess import sys @@ -271,7 +270,7 @@ class LauncherHandle: try: await self.client.generate("test") return - except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: + except (ClientConnectorError, ClientOSError, ServerDisconnectedError): time.sleep(1) raise RuntimeError("Health check failed") diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py index 10df6dbd..1f7a4a59 100644 --- a/integration-tests/models/test_chat_llama.py +++ b/integration-tests/models/test_chat_llama.py @@ -1,7 +1,4 @@ import pytest -import json - -from text_generation.types import GrammarType @pytest.fixture(scope="module") diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index 3ead3150..52ecaed4 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -1,6 +1,4 @@ import pytest -import requests -import io import base64 diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index b7725f0b..eb573385 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot): generated_texts = [r.generated_text for r in responses] - assert ( - generated_texts[0] == " \nAssistant: A rooster stands" - ), f"{response.generated_text}" + assert generated_texts[0] == " \nAssistant: A rooster stands" assert len(generated_texts) == 4 assert generated_texts, all( [text == generated_texts[0] for text in generated_texts] diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 0af3f66a..f831990a 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -1,7 +1,4 @@ import pytest -import json - -from text_generation.types import GrammarType @pytest.fixture(scope="module") @@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna }, ], ) - assert response.choices[0].message.content == None + assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { "id": 0, @@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto( }, ], ) - assert response.choices[0].message.content == None + assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { "id": 0, @@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice( }, ], ) - assert response.choices[0].message.content == None + assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { "id": 0, @@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( stream=False, ) - assert responses.choices[0].message.content == None + assert responses.choices[0].message.content is None assert responses.choices[0].message.tool_calls == [ { "function": { diff --git a/load_tests/orca.py b/load_tests/orca.py index e607d27c..e445afd5 100644 --- a/load_tests/orca.py +++ b/load_tests/orca.py @@ -20,7 +20,7 @@ def main(): break with open("./small.json", "w") as f: - data = json.dump(conversations, f, indent=4) + json.dump(conversations, f, indent=4) if __name__ == "__main__": diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index 721820f5..291a41b0 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -1,11 +1,9 @@ import os -import requests import tempfile import pytest import huggingface_hub.constants -from huggingface_hub import hf_api import text_generation_server.utils.hub from text_generation_server.utils.hub import ( diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 1e3aaf6b..118540ee 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -2,7 +2,6 @@ import torch from text_generation_server.layers import ( TensorParallelEmbedding, ) -from text_generation_server.utils.weights import DefaultWeightsLoader class ProcessGroup: diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 97e30c46..556fcea1 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -2,7 +2,6 @@ import pytest import torch from text_generation_server.utils.weights import ( DefaultWeightsLoader, - UnquantizedWeight, Weights, WeightsLoader, ) @@ -86,15 +85,6 @@ dummy_file_system = { ], dtype=torch.float32, ), - "weight.weight": torch.tensor( - [ - [1, 2], - [3, 4], - [5, 6], - [7, 8], - ], - dtype=torch.float32, - ), }, "test_get_weights_row": { "weight.weight": torch.tensor( @@ -966,7 +956,7 @@ def test_get_multi_weights_col_exl2(): prefix = "weight" try: - w = weights.get_multi_weights_col( + weights.get_multi_weights_col( prefixes=[prefix], dim=0, ) diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py index 2ee53b12..b7e27090 100644 --- a/server/text_generation_server/adapters/config.py +++ b/server/text_generation_server/adapters/config.py @@ -4,15 +4,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Set, Tuple +from typing import Dict, Set, Tuple import torch from text_generation_server.adapters.weights import AdapterWeights -if TYPE_CHECKING: - from text_generation_server.models.model import Model - @dataclass class ModuleMap: diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index ac143bb7..a00338e7 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch from peft import LoraConfig as _LoraConfig @@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import ( use_cutlass_shrink, ) -if TYPE_CHECKING: - from text_generation_server.models.model import Model - def get_start_stop_idxs_for_rank(offset, size, rank, world_size): block_size = size // world_size diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 8b5520bf..10aa3a3b 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -4,12 +4,11 @@ import typer from pathlib import Path from loguru import logger -from typing import Optional, List, Dict +from typing import Optional from enum import Enum from huggingface_hub import hf_hub_download from text_generation_server.utils.adapter import parse_lora_adapters -from text_generation_server.utils.log import log_master app = typer.Typer() @@ -165,7 +164,7 @@ def download_weights( # currently by default we don't merge the weights with the base model if merge_lora: try: - adapter_config_filename = hf_hub_download( + hf_hub_download( model_id, revision=revision, filename="adapter_config.json" ) utils.download_and_unload_peft( @@ -285,9 +284,9 @@ def download_weights( if auto_convert: if not trust_remote_code: logger.warning( - f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because " - f"Pickle files are unsafe and can essentially contain remote code execution!" - f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety", + "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because " + "Pickle files are unsafe and can essentially contain remote code execution!" + "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety", ) logger.warning( @@ -319,7 +318,7 @@ def download_weights( # Name for this varible depends on transformers version. discard_names = getattr(class_, "_tied_weights_keys", []) - except Exception as e: + except Exception: discard_names = [] # Convert pytorch weights to safetensors utils.convert_files(local_pt_files, local_st_files, discard_names) diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index 32c8d121..0000ca91 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -18,3 +18,17 @@ from text_generation_server.layers.lora import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) + +__all__ = [ + "get_linear", + "FastLinear", + "TensorParallelColumnLinear", + "TensorParallelRowLinear", + "TensorParallelEmbedding", + "SpeculativeHead", + "LoraLinear", + "TensorParallelMultiAdapterLinear", + "TensorParallelAdapterRowLinear", + "load_layer_norm", + "load_conv2d", +] diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index c8bccefe..f9b1715e 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -13,3 +13,12 @@ elif SYSTEM == "ipex": from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") + + +__all__ = [ + "attention", + "paged_attention", + "reshape_and_cache", + "SUPPORTS_WINDOWING", + "Seqlen", +] diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 8ff51596..c0c4da4d 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -10,7 +10,6 @@ _PARTITION_SIZE = 512 try: from vllm._C import cache_ops - from vllm._C import ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" diff --git a/server/text_generation_server/layers/attention/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py index 3fe32231..3a6f9a73 100644 --- a/server/text_generation_server/layers/attention/flash_attn_triton.py +++ b/server/text_generation_server/layers/attention/flash_attn_triton.py @@ -747,11 +747,8 @@ class _attention(torch.autograd.Function): padded_d_model = 1 << (head_size - 1).bit_length() padded_d_model = max(padded_d_model, 16) - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) + def grid(META): + return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch encoded_softmax = None diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 54da63e8..3ebe492a 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -15,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck" try: from vllm._C import cache_ops - from vllm._C import ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py index c859db1b..391371a5 100644 --- a/server/text_generation_server/layers/awq/quantize/qmodule.py +++ b/server/text_generation_server/layers/awq/quantize/qmodule.py @@ -1,6 +1,5 @@ # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py -import math from typing import Optional import torch import torch.nn as nn diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index aae2bd1a..791d9b6d 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from functools import lru_cache import bitsandbytes as bnb import torch diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 54550ee3..59b08b55 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -12,17 +12,26 @@ from text_generation_server.utils.weights import ( Weights, ) from text_generation_server.utils.log import log_master, log_once +import importlib.util + FBGEMM_MM_AVAILABLE = False FBGEMM_DYN_AVAILABLE = False -try: - import fbgemm_gpu.experimental.gen_ai + +def is_fbgemm_gpu_available(): + try: + return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None + except ModuleNotFoundError: + return False + + +if is_fbgemm_gpu_available(): if SYSTEM == "cuda": major, _ = torch.cuda.get_device_capability() FBGEMM_MM_AVAILABLE = major == 9 FBGEMM_DYN_AVAILABLE = major >= 8 -except (ImportError, ModuleNotFoundError): +else: log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 6d7495b7..4fb151a9 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -8,6 +8,34 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 + +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + if V2: + from text_generation_server.layers.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, # noqa: F401 + ) + + HAS_EXLLAMA = "2" + else: + from text_generation_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 + ) + + HAS_EXLLAMA = "1" + + except ImportError: + pass + @dataclass class GPTQWeight(Weight): @@ -55,7 +83,7 @@ class GPTQWeight(Weight): from text_generation_server.layers.gptq import ExllamaQuantLinear except ImportError: raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" ) return ExllamaQuantLinear(self, bias) @@ -73,45 +101,6 @@ class GPTQWeight(Weight): ) -try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - -HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, - ) - from text_generation_server.layers.gptq.exllamav2 import ( - create_exllama_buffers, - set_device, - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, - ) - from text_generation_server.layers.gptq.exllama import ( - create_exllama_buffers, - set_device, - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass - -from text_generation_server.layers.gptq.quant_linear import QuantLinear - - class GPTQWeightsLoader(WeightsLoader): """ Loader for GPTQ- and AWQ-quantized weights. diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/quant_linear.py index f60758b6..cebf9925 100644 --- a/server/text_generation_server/layers/gptq/quant_linear.py +++ b/server/text_generation_server/layers/gptq/quant_linear.py @@ -206,10 +206,12 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): output = torch.empty( (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + matmul_248_kernel[grid]( input, qweight, diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 0271d913..42c986b0 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -15,6 +15,7 @@ from text_generation_server.utils.hub import weight_files from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional +from text_generation_server.layers.gptq.utils import torch_snr_error from text_generation_server.utils.weights import DefaultWeightsLoader @@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -700,6 +701,8 @@ def sequential( pass def add_batch(name): + nonlocal gptq + def tmp(_, inp, out): gptq[name].add_batch(inp[0].data, out.data) diff --git a/server/text_generation_server/layers/gptq/utils.py b/server/text_generation_server/layers/gptq/utils.py new file mode 100644 index 00000000..cbc0f391 --- /dev/null +++ b/server/text_generation_server/layers/gptq/utils.py @@ -0,0 +1,56 @@ +import torch + + +# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py +def torch_snr_error( + y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean" +) -> torch.Tensor: + """ + Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calcualted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + torch.Tensor: _description_ + """ + if y_pred.shape != y_real.shape: + raise ValueError( + f"Can not compute snr loss for tensors with different shape. " + f"({y_pred.shape} and {y_real.shape})" + ) + reduction = str(reduction).lower() + + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(0) + y_real = y_real.unsqueeze(0) + + y_pred = y_pred.flatten(start_dim=1) + y_real = y_real.flatten(start_dim=1) + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == "mean": + return torch.mean(snr) + elif reduction == "sum": + return torch.sum(snr) + elif reduction == "none": + return snr + else: + raise ValueError("Unsupported reduction method.") diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index a97cc43a..12d7f83a 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index df5e92da..a4537b55 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -1,12 +1,8 @@ -import math -import os -from typing import TYPE_CHECKING, Optional, Tuple, List +from typing import TYPE_CHECKING, Optional, List import torch import torch.distributed -from accelerate import init_empty_weights from torch import nn -from torch.nn import functional as F from torch.distributed import ProcessGroup from text_generation_server.utils.sgmv import ( diff --git a/server/text_generation_server/layers/marlin/__init__.py b/server/text_generation_server/layers/marlin/__init__.py index 40147d59..8a143f96 100644 --- a/server/text_generation_server/layers/marlin/__init__.py +++ b/server/text_generation_server/layers/marlin/__init__.py @@ -1,6 +1,3 @@ -from typing import List, Tuple - -import torch from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear from text_generation_server.layers.marlin.gptq import ( GPTQMarlinLinear, diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py index db3ce2d7..89ebaca6 100644 --- a/server/text_generation_server/layers/marlin/marlin.py +++ b/server/text_generation_server/layers/marlin/marlin.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch import torch.nn as nn @@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader): ) except RuntimeError: raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" + "Cannot load `marlin` weight, make sure the model is already quantized" ) B_meta = torch.cat( @@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader): ) except RuntimeError: raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" + "Cannot load `marlin` weight, make sure the model is already quantized" ) s = torch.cat( [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index e6e96abb..fc4a59b9 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -2,12 +2,9 @@ import os import math import torch from torch import nn -from loguru import logger - from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "cuda": - from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops diff --git a/server/text_generation_server/layers/speculative.py b/server/text_generation_server/layers/speculative.py index 4b977a56..cf8469b5 100644 --- a/server/text_generation_server/layers/speculative.py +++ b/server/text_generation_server/layers/speculative.py @@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module): except KeyError: try: speculator = MedusaHeadV1.load(config, prefix, weights) - except: + except Exception: speculator = MedusaHeadV2(config, prefix, weights) lm_head = None else: diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 9dddb8ae..13f12ef1 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -2,7 +2,6 @@ import torch from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear -from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "ipex": @@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer): # If the piece and LM head embeddings are shared, we have # non-quantized weights... weight = weights.get_tensor(f"{prefix}.weight") - except: + except Exception: # ...otherwise they are quantized. weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 @@ -67,15 +66,6 @@ class TensorParallelHead(SuperLayer): weight = weights.get_tensor(f"{prefix}.weight") should_gather = False - # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq", "marlin"]: - quantize = None - # See above, exl2 LM head can be quantized or not. - elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight): - quantize = None - else: - quantize = config.quantize - return TensorParallelHead( get_linear(weight, bias=None), process_group=weights.process_group, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ac7a8f3e..3dc24159 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,3 +1,6 @@ +# ruff: noqa: F821 +# the above line disables the `undefined-name` rule for the model type variables + import torch import enum import os @@ -712,6 +715,7 @@ def get_model( ) elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: + print(f">>> model_type: {model_type}") if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, @@ -856,7 +860,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, ) - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): return FlashCausalLM( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 87b9969a..212ab7a9 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -233,7 +233,7 @@ class CausalLMBatch(Batch): ] # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: + if type(self.past_key_values[0]) is tuple: self.past_key_values = [list(layer) for layer in self.past_key_values] # Update tensors in-place to allow incremental garbage collection @@ -377,7 +377,7 @@ class CausalLMBatch(Batch): # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: + if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 77b89c5b..e2719fad 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel): loss = None if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] + output = (logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return ( diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 27b9ff1c..ab824da5 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from torch import nn @@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, ) from transformers.modeling_outputs import ( - BaseModelOutput, BaseModelOutputWithPooling, - ImageClassifierOutput, ) from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig @@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module): class CLIPTextTransformer(nn.Module): - def __init__(self, prefix: str, config: CLIPTextConfig): + def __init__(self, prefix: str, config: CLIPTextConfig, weights=None): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = CLIPTextEmbeddings(config) + # Initialize weights and apply final processing with `self.post_init()` self.encoder = CLIPEncoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) @@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module): # text_embeds.shape = [batch_size, sequence_length, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 - pooled_output = last_hidden_state[ + last_hidden_state[ torch.arange( last_hidden_state.shape[0], device=last_hidden_state.device ), @@ -515,7 +514,7 @@ class CLIPTextTransformer(nn.Module): ] else: # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) - pooled_output = last_hidden_state[ + last_hidden_state[ torch.arange( last_hidden_state.shape[0], device=last_hidden_state.device ), @@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) return self.text_model( input_ids=input_ids, @@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module): def __init__(self, prefix, config: CLIPVisionConfig, weights): super().__init__() self.config = config - embed_dim = config.hidden_size self.embeddings = CLIPVisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights @@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) return self.vision_model( pixel_values=pixel_values, @@ -799,14 +791,12 @@ class CLIPModel(nn.Module): # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. vision_outputs = self.vision_model( pixel_values=pixel_values, - return_dict=return_dict, ) text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - return_dict=return_dict, ) image_embeds = vision_outputs[1] diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index ab10dee1..b73dada6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,7 +30,6 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 7426fc55..3c301c8d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) -from text_generation_server.utils.log import log_once class DbrxAttentionConfig(PretrainedConfig): diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index f5b2ba0e..3e84b4a8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -39,6 +39,12 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig +if SYSTEM == "rocm": + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + class DeepseekV2Config(PretrainedConfig): def __init__( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3e8e67ab..b55ddc23 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import ( - UnquantizedWeight, Weights, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader @@ -277,7 +276,7 @@ class LlamaMLP(nn.Module): bias=bias, ) else: - prefixes = [f"gate_proj", f"up_proj"] + prefixes = ["gate_proj", "up_proj"] sizes = [ config.intermediate_size, config.intermediate_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index eca01bbb..3d8f6bf4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( - Seqlen, paged_attention, attention, reshape_and_cache, @@ -38,7 +37,6 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - get_linear, TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index aae6737e..7cdca553 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -21,7 +21,6 @@ import torch import torch.distributed -import numpy as np from torch import nn from text_generation_server.utils.import_utils import SYSTEM @@ -31,7 +30,6 @@ if SYSTEM != "ipex": from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from loguru import logger from text_generation_server.layers.attention import ( paged_attention, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 1f998e5a..d10efb41 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -16,7 +16,6 @@ import torch import torch.distributed from torch import nn -from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index f40d126b..e357a287 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -15,7 +15,6 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - get_linear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index f540b7af..7e4deaf8 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -14,7 +14,7 @@ # limitations under the License. """ PyTorch Idefics2 model.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.utils.checkpoint @@ -22,10 +22,8 @@ from torch import nn import math from transformers.activations import ACT2FN -from transformers.image_processing_utils import select_best_resolution from text_generation_server.models.custom_modeling.vlm import ( load_text_model, - load_vision_model, ) from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py index e323d365..afb8e1f9 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py @@ -19,6 +19,7 @@ import numpy as np from PIL import Image +import transformers from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_transforms import ( resize, @@ -293,6 +294,4 @@ class IdeficsImageProcessor(BaseImageProcessor): return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs) -import transformers - transformers.IdeficsImageProcessor = IdeficsImageProcessor diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 786ef559..771bc234 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -21,10 +21,8 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel from transformers.activations import ACT2FN @@ -33,13 +31,6 @@ from transformers.modeling_outputs import ( CausalLMOutputWithPast, dataclass, ) -from transformers.modeling_utils import PretrainedConfig -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig from text_generation_server.models.custom_modeling.idefics_vision import ( IdeficsVisionTransformer, @@ -56,6 +47,7 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.utils.import_utils import SYSTEM +from loguru import logger if SYSTEM == "cuda": import dropout_layer_norm @@ -237,7 +229,7 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module): prefix="model.embed_tokens", weights=weights ) self.additional_weight = nn.Parameter( - weights.get_tensor(f"model.embed_tokens.additional_embedding.weight") + weights.get_tensor("model.embed_tokens.additional_embedding.weight") ) def forward(self, input_ids): @@ -499,7 +491,6 @@ class IdeficsAttention(nn.Module): # if not hasattr(nn.functional, "scaled_dot_product_attention"): # raise ValueError("this model requires pytorch 2.0 or higher") - process_group = weights.process_group if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -1024,7 +1015,7 @@ class IdeficsModel(IdeficsPreTrainedModel): if config.use_resampler: perceiver_config = config.perceiver_config self.perceiver_resampler = IdeficsPerceiverResampler( - prefix=f"model.perceiver_resampler", + prefix="model.perceiver_resampler", config=config, embed_dim=config.vision_config.embed_dim, depth=perceiver_config.resampler_depth, @@ -1052,7 +1043,7 @@ class IdeficsModel(IdeficsPreTrainedModel): # self.gradient_checkpointing = False self.norm = IdeficsRMSNorm( - prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps + prefix="model.norm", weights=weights, eps=config.rms_norm_eps ) # self.gradient_checkpointing = False diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py index af44490b..6da8045b 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py +++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py @@ -169,7 +169,6 @@ class IdeficsPerceiverAttention(nn.Module): self.qk_scale = self.head_dim**-0.5 - process_group = weights.process_group if n_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} " diff --git a/server/text_generation_server/models/custom_modeling/idefics_processing.py b/server/text_generation_server/models/custom_modeling/idefics_processing.py index 7bba6977..ca61e27d 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_processing.py @@ -28,9 +28,6 @@ from transformers.tokenization_utils_base import ( TruncationStrategy, ) from transformers.utils import TensorType, is_torch_available -from text_generation_server.models.custom_modeling.idefics_image_processing import ( - IdeficsImageProcessor, -) if is_torch_available(): diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py index 30c5997f..dd8f76bc 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_vision.py +++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -129,7 +129,6 @@ class IdeficsVisionAttention(nn.Module): self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - process_group = weights.process_group if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -460,7 +459,6 @@ class IdeficsVisionTransformer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config - embed_dim = config.hidden_size self.embeddings = IdeficsVisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index e154d805..29f5b9c7 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -14,7 +14,7 @@ # limitations under the License. """ PyTorch Llava-NeXT model.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.utils.checkpoint diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 04c547b2..988a74a3 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ import math -import os import warnings from typing import List, Optional, Tuple, Union import torch @@ -194,7 +193,7 @@ def flash_attn_fn( ): try: from flash_attn import bert_padding, flash_attn_interface - except: + except Exception: raise RuntimeError("Please install flash-attn==1.0.3.post0") check_valid_inputs(query, key, value) if past_key_value is not None: @@ -207,7 +206,7 @@ def flash_attn_fn( _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: - raise NotImplementedError(f"attn_bias not implemented for flash attn.") + raise NotImplementedError("attn_bias not implemented for flash attn.") (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) @@ -269,13 +268,13 @@ def triton_flash_attn_fn( ): try: from .flash_attn_triton import flash_attn_func - except: + except Exception: _installed = False if version.parse(torch.__version__) < version.parse("2.0.0"): _installed = True try: from flash_attn.flash_attn_triton import flash_attn_func - except: + except Exception: _installed = False if not _installed: raise RuntimeError( @@ -292,9 +291,9 @@ def triton_flash_attn_fn( _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if dropout_p: - raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") + raise NotImplementedError("Dropout not implemented for attn_impl: triton.") if needs_weights: - raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") + raise NotImplementedError("attn_impl: triton cannot return attn weights.") if key_padding_mask is not None: warnings.warn( "Propagating key_padding_mask to the attention module " @@ -428,7 +427,7 @@ class MultiQueryAttention(nn.Module): additive bias. """ - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix, weights, verbose=False): super().__init__() attn_impl = config.attn_config.attn_impl self.attn_impl = config.attn_config.attn_impl @@ -445,7 +444,7 @@ class MultiQueryAttention(nn.Module): self.Wqkv = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias ) - fuse_splits = (d_model, d_model + self.head_dim) + (d_model, d_model + self.head_dim) if self.qk_ln: raise NotImplementedError("qk_ln not supported") if self.attn_impl == "flash": @@ -795,7 +794,9 @@ class MPTModel(MPTPreTrainedModel): self.alibi = config.attn_config.alibi self.alibi_bias_max = config.attn_config.alibi_bias_max if config.init_device == "mixed": - if dist.get_local_rank() == 0: + # TODO: reimplement mixed device initialization + # dist.get_local_rank() == 0: + if True: config.init_device = "cpu" else: config.init_device = "meta" @@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel): if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError( - f"past_key_values must provide a past_key_value for each attention " + "past_key_values must provide a past_key_value for each attention " + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." ) past_position = past_key_values[0][0].size(1) @@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel): input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) - if kwargs.get("use_cache") == False: + if kwargs.get("use_cache") is False: raise NotImplementedError( "MPT with prefix_lm=True does not support use_cache=False." ) diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 8998778f..06731a6f 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -21,25 +21,14 @@ import torch import torch.distributed import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN -from transformers.file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel -from transformers import GPTNeoXConfig -from loguru import logger from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -133,7 +122,6 @@ class GPTNeoXAttention(nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_attention_heads self.rotary_ndims = int(self.head_size * config.rotary_pct) - max_positions = config.max_position_embeddings # ??? TODO # self.register_buffer( # "bias", diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index b4d56db1..3f2ed010 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -5,7 +5,7 @@ import torch.distributed import math from torch import nn -from typing import Optional, List, Tuple, Any +from typing import Optional, List, Tuple from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index 5fbc6d29..480d0f9f 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -1,20 +1,15 @@ -from typing import Optional, Tuple, Union - +from typing import Optional, Tuple +import warnings import math import torch from torch import nn from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import ( - _create_4d_causal_attention_mask, - _prepare_4d_attention_mask, -) from transformers.modeling_outputs import ( - BaseModelOutput, BaseModelOutputWithPooling, - ImageClassifierOutput, ) -from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig +from transformers import SiglipConfig, SiglipVisionConfig +from torch.nn.init import _calculate_fan_in_and_fan_out from text_generation_server.layers.tensor_parallel import ( TensorParallelEmbedding, @@ -244,9 +239,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -import warnings - - def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -264,12 +256,12 @@ def _trunc_normal_(tensor, mean, std, a, b): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * lower - 1, 2 * upper - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal @@ -313,9 +305,6 @@ def trunc_normal_tf_( tensor.mul_(std).add_(mean) -from torch.nn.init import _calculate_fan_in_and_fan_out - - def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": @@ -349,9 +338,6 @@ def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") -from transformers import PreTrainedModel - - class SiglipEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a @@ -393,7 +379,6 @@ class SiglipVisionTransformer(nn.Module): def __init__(self, prefix, config: SiglipVisionConfig, weights): super().__init__() self.config = config - embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 0b899fba..e6666acd 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -45,6 +45,15 @@ from text_generation_server.layers import ( SpeculativeHead, ) +# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + class PartialTPEmbedding(nn.Module): def __init__(self, prefix: str, weights): @@ -1132,12 +1141,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) # move labels to correct device to enable PP - labels = labels.to(lm_logits.device) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + output = (logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output return ( diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index b74b43ff..e5c44045 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -42,7 +42,7 @@ def load_vision_model(prefix, config, weights): ) return SiglipVisionTransformer( - prefix=f"vision_tower.vision_model", config=config, weights=weights + prefix="vision_tower.vision_model", config=config, weights=weights ) else: raise RuntimeError(f"Unsupported model type {config.model_type}") diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2ca40959..152516e7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1194,7 +1194,7 @@ class FlashCausalLM(Model): if self.speculate is None or self.speculate + 1 <= bs: self.cuda_graph_warmup(bs, max_s, max_bt) except torch.cuda.OutOfMemoryError: - logger.exception(f"Decode cuda graph warmup failed") + logger.exception("Decode cuda graph warmup failed") else: log_master( logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 2d43244a..7c4e462c 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -2,23 +2,15 @@ import re import torch import torch.distributed -from typing import List, Optional, Type from transformers import ( - AutoTokenizer, - AutoConfig, PreTrainedTokenizerBase, ) -from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, - initialize_torch_distributed, - weight_files, - Weights, ) from text_generation_server.utils.chunks import concat_text_chunks diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 0deab6ce..af224254 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -1,13 +1,8 @@ import torch import torch.distributed -from typing import List, Optional, Tuple +from typing import Optional -from transformers import ( - AutoTokenizer, - AutoConfig, - AutoProcessor, -) from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig from text_generation_server.models.custom_modeling.idefics_processing import ( diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 6c562980..8a80ed68 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch): image_hidden_states = self.image_hidden_states[keep_indices] # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: + if type(self.past_key_values[0]) is tuple: self.past_key_values = [list(layer) for layer in self.past_key_values] # Update tensors in-place to allow incremental garbage collection @@ -456,7 +456,7 @@ class IdeficsCausalLMBatch(Batch): # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: + if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 4ed9722c..5d6ce3c7 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -2,7 +2,6 @@ import torch import torch.distributed from transformers import AutoTokenizer, PreTrainedTokenizerBase from typing import Optional -import os from text_generation_server.models.custom_modeling.mamba_modeling import ( MambaConfig, ) @@ -20,7 +19,7 @@ from text_generation_server.models.custom_modeling.mamba_modeling import ( InferenceParams, ) from text_generation_server.models import Model -from typing import Any, List, Optional, Tuple, Type, Dict +from typing import Any, List, Tuple, Type, Dict from text_generation_server.models.types import ( Batch, Tokens, @@ -31,7 +30,7 @@ from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.utils import NextTokenChooser, StoppingCriteria def new_inference_params( @@ -299,7 +298,6 @@ class MambaBatch(Batch): stopping_criterias = [] top_n_tokens = [] max_tokens = 0 - max_seqlen = 0 seqlen_offset = 0 (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape @@ -485,7 +483,7 @@ class Mamba(Model): for bs in CUDA_GRAPHS: self.cuda_graph_warmup(bs) except Exception: - logger.exception(f"Decode cuda graph warmup failed") + logger.exception("Decode cuda graph warmup failed") else: logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") @@ -534,7 +532,7 @@ class Mamba(Model): } self.cuda_graphs[batch_size] = graph_dict - def tunableop_warmup(self, seqlen: int): + def tunableop_warmup(self, batch_size: int, seqlen: int): input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) n_blocks = len(self.model.blocks) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 159139de..20402e07 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -2,7 +2,7 @@ import inspect import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict +from typing import List, Tuple, Optional, TypeVar, Type, Dict from collections import defaultdict from transformers import PreTrainedTokenizerBase diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index 3994ac70..fe75570e 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -3,16 +3,11 @@ from PIL import Image import torch import torch.distributed from opentelemetry import trace -from typing import Iterable, Optional, Tuple +from typing import Iterable from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLM, VlmCausalLMBatch, image_text_replacement, ) -from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( - PaliGemmaForConditionalGeneration, -) -from transformers import AutoProcessor, AutoConfig from text_generation_server.pb.generate_pb2 import Request diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index fa8b5025..79c001b0 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,7 +1,6 @@ import torch import torch.distributed import time - from dataclasses import dataclass from opentelemetry import trace from transformers import ( @@ -11,7 +10,7 @@ from transformers import ( AutoConfig, ) from typing import Optional, Tuple, List, Type, Dict - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -254,7 +253,7 @@ class Seq2SeqLMBatch(Batch): ] # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: + if type(self.past_key_values[0]) is tuple: self.past_key_values = [ [t for t in layer] for layer in self.past_key_values ] @@ -430,7 +429,7 @@ class Seq2SeqLMBatch(Batch): batch.encoder_last_hidden_state = None # Ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: + if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ [t for t in layer] for layer in batch.past_key_values ] diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 339b733b..d4e7cca7 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -1,4 +1,3 @@ -from functools import total_ordering import torch from abc import ABC, abstractmethod diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 7ac54603..22bd759f 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -9,7 +9,7 @@ from loguru import logger from grpc_reflection.v1alpha import reflection from pathlib import Path -from typing import List, Optional, Dict +from typing import List, Optional from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 011e0f63..7c053014 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,15 +1,13 @@ import torch from loguru import logger -import subprocess import os +import importlib.util + + def is_ipex_available(): - try: - import intel_extension_for_pytorch - except ImportError: - return False - return True + return importlib.util.find_spec("intel_extension_for_pytorch") is not None def get_cuda_free_memory(device, memory_fraction): diff --git a/server/text_generation_server/utils/merges/strategies.py b/server/text_generation_server/utils/merges/strategies.py index 3b885313..cb39cde1 100644 --- a/server/text_generation_server/utils/merges/strategies.py +++ b/server/text_generation_server/utils/merges/strategies.py @@ -2,9 +2,17 @@ import copy from abc import ABC from collections import defaultdict from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union - +from text_generation_server.utils.merges.utils import ( + calculate_majority_sign_mask, + disjoint_merge, + prune, +) import torch +if TYPE_CHECKING: + from text_generation_server.adapters.lora import LoraConfig + from text_generation_server.utils.adapter import ModuleMap + class AdapterParameters: def __init__( @@ -17,17 +25,6 @@ class AdapterParameters: self.majority_sign_method = majority_sign_method -from text_generation_server.utils.merges.utils import ( - calculate_majority_sign_mask, - disjoint_merge, - prune, -) - -if TYPE_CHECKING: - from text_generation_server.adapters.lora import LoraConfig - from text_generation_server.utils.adapter import ModuleMap - - def _apply_weights( tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor ) -> torch.Tensor: diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index 0ea89267..d49e73f0 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -28,7 +28,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): low_cpu_mem_usage=True, ) logger.info("Peft model detected.") - logger.info(f"Merging the lora weights.") + logger.info("Merging the lora weights.") base_model_id = model.peft_config["default"].base_model_name_or_path diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 98dae079..cfb8b1db 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -6,7 +6,6 @@ from typing import Optional from huggingface_hub import hf_hub_download from text_generation_server.utils.weights import ( DefaultWeightsLoader, - UnquantizedWeight, WeightsLoader, ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 22f86b60..9ab49665 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,7 +1,6 @@ import re from typing import List, Optional, Tuple, Set, Union -import math import torch from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType From 2c1d280faed14c8c6ff884f8d953875d326334d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Mon, 29 Jul 2024 11:14:17 +0200 Subject: [PATCH 171/340] Run ci api key (#2315) * Add API_Key for Auth and conditionally add authorisation for non info/health endpoints. * change name to info routes * Fix comment * convert strings to lowercase for case insensitive comparison * convert header to string * fixes and update docs * update docs again * revert wrong update --------- Co-authored-by: Kevin Duffy --- docs/source/basic_tutorials/launcher.md | 6 ++++ launcher/src/main.rs | 9 ++++++ router/src/main.rs | 4 +++ router/src/server.rs | 37 ++++++++++++++++++++++--- 4 files changed, 52 insertions(+), 4 deletions(-) diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 77f88490..ce98876f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -349,6 +349,12 @@ Options: --cors-allow-origin [env: CORS_ALLOW_ORIGIN=] +``` +## API_KEY +```shell + --api-key + [env: API_KEY=] + ``` ## WATERMARK_GAMMA ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 3d3b068e..15233306 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -424,6 +424,10 @@ struct Args { #[clap(long, env)] cors_allow_origin: Vec, + + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] watermark_gamma: Option, #[clap(long, env)] @@ -1275,6 +1279,11 @@ fn spawn_webserver( router_args.push(origin); } + // API Key + if let Some(api_key) = args.api_key { + router_args.push("--api-key".to_string()); + router_args.push(api_key); + } // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index bfc77913..36879aa4 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -77,6 +77,8 @@ struct Args { #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] + api_key: Option, + #[clap(long, env)] ngrok: bool, #[clap(long, env)] ngrok_authtoken: Option, @@ -127,6 +129,7 @@ async fn main() -> Result<(), RouterError> { otlp_endpoint, otlp_service_name, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, @@ -446,6 +449,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, addr, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, diff --git a/router/src/server.rs b/router/src/server.rs index c56c39a3..0fd5aade 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -37,6 +37,7 @@ use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; @@ -1417,6 +1418,7 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, + api_key: Option, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, @@ -1810,16 +1812,42 @@ pub async fn run( let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes - let base_routes = Router::new() + let mut base_routes = Router::new() .route("/", post(compat_generate)) - .route("/", get(health)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + if let Some(api_key) = api_key { + let mut prefix = "Bearer ".to_string(); + prefix.push_str(&api_key); + + // Leak to allow FnMut + let api_key: &'static str = prefix.leak(); + + let auth = move |headers: HeaderMap, + request: axum::extract::Request, + next: axum::middleware::Next| async move { + match headers.get(AUTHORIZATION) { + Some(token) => match token.to_str() { + Ok(token_str) if token_str.to_lowercase() == api_key.to_lowercase() => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + }, + None => Err(StatusCode::UNAUTHORIZED), + } + }; + + base_routes = base_routes.layer(axum::middleware::from_fn(auth)) + } + let info_routes = Router::new() + .route("/", get(health)) + .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) .route("/metrics", get(metrics)); @@ -1838,6 +1866,7 @@ pub async fn run( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) + .merge(info_routes) .merge(aws_sagemaker_route); #[cfg(feature = "google")] From 23a3927eb6d84518cb8f9d072fa5e4f402d19f18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 29 Jul 2024 15:37:10 +0200 Subject: [PATCH 172/340] Install Marlin from standalone package (#2320) --- server/marlin/COPYRIGHT | 20 - server/marlin/marlin_kernels/__init__.pyi | 76 - .../marlin_kernels/awq_marlin_repack.cu | 269 -- server/marlin/marlin_kernels/ext.cpp | 16 - server/marlin/marlin_kernels/ext.hh | 39 - server/marlin/marlin_kernels/fp8_marlin.cu | 1305 ---------- server/marlin/marlin_kernels/gptq_marlin.cu | 2195 ----------------- .../marlin_kernels/gptq_marlin_repack.cu | 344 --- server/marlin/marlin_kernels/marlin.cuh | 87 - .../marlin_kernels/marlin_cuda_kernel.cu | 1138 --------- .../marlin/marlin_kernels/marlin_dtypes.cuh | 79 - server/marlin/marlin_kernels/py.typed | 0 .../marlin_kernels/sparse/common/base.h | 51 - .../marlin/marlin_kernels/sparse/common/mem.h | 136 - .../marlin/marlin_kernels/sparse/common/mma.h | 191 -- .../sparse/marlin_24_cuda_kernel.cu | 1125 --------- server/marlin/setup.py | 24 - server/poetry.lock | 71 +- server/pyproject.toml | 8 + .../layers/marlin/gptq.py | 1 + 20 files changed, 79 insertions(+), 7096 deletions(-) delete mode 100644 server/marlin/COPYRIGHT delete mode 100644 server/marlin/marlin_kernels/__init__.pyi delete mode 100644 server/marlin/marlin_kernels/awq_marlin_repack.cu delete mode 100644 server/marlin/marlin_kernels/ext.cpp delete mode 100644 server/marlin/marlin_kernels/ext.hh delete mode 100644 server/marlin/marlin_kernels/fp8_marlin.cu delete mode 100644 server/marlin/marlin_kernels/gptq_marlin.cu delete mode 100644 server/marlin/marlin_kernels/gptq_marlin_repack.cu delete mode 100644 server/marlin/marlin_kernels/marlin.cuh delete mode 100644 server/marlin/marlin_kernels/marlin_cuda_kernel.cu delete mode 100644 server/marlin/marlin_kernels/marlin_dtypes.cuh delete mode 100644 server/marlin/marlin_kernels/py.typed delete mode 100644 server/marlin/marlin_kernels/sparse/common/base.h delete mode 100644 server/marlin/marlin_kernels/sparse/common/mem.h delete mode 100644 server/marlin/marlin_kernels/sparse/common/mma.h delete mode 100644 server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu delete mode 100644 server/marlin/setup.py diff --git a/server/marlin/COPYRIGHT b/server/marlin/COPYRIGHT deleted file mode 100644 index 69f3b8e6..00000000 --- a/server/marlin/COPYRIGHT +++ /dev/null @@ -1,20 +0,0 @@ -These kernels were vendored from VLLM. The Marlin kernels were developed -by Elias Frantar and extended by Neural Magic. - ---- - -Copyright (C) Marlin.2024 Elias Frantar -Modified by Neural Magic -Copyright 2024 The vLLM team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi deleted file mode 100644 index 53464719..00000000 --- a/server/marlin/marlin_kernels/__init__.pyi +++ /dev/null @@ -1,76 +0,0 @@ -import torch - -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. This is an extension of - `marlin_gemm` that supports converted GPTQ kernels. - """ - ... - -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. This is an extension of - `marlin_gemm` that supports 2:4 sparsity. - """ - ... - -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - """Repack GPTQ parameters for Marlin kernels.""" - ... - -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. - """ - ... - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return torch.ops._C.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) diff --git a/server/marlin/marlin_kernels/awq_marlin_repack.cu b/server/marlin/marlin_kernels/awq_marlin_repack.cu deleted file mode 100644 index c58216d8..00000000 --- a/server/marlin/marlin_kernels/awq_marlin_repack.cu +++ /dev/null @@ -1,269 +0,0 @@ -#include "marlin.cuh" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -namespace marlin { - -template -__global__ void awq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace marlin - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -namespace marlin { - -template -__global__ void awq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) { - constexpr int pack_factor = 32 / num_bits; - - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; - int block_k_tiles = div_ceil(k_tiles, gridDim.x); - - int start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) { - return; - } - - int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - extern __shared__ int4 sh[]; - - constexpr int tile_n_ints = tile_n_size / pack_factor; - - constexpr int stage_n_threads = tile_n_ints / 4; - constexpr int stage_k_threads = tile_k_size; - constexpr int stage_size = stage_k_threads * stage_n_threads; - - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - cp_async_fence(); - return; - } - - int first_n = n_tile_id * tile_n_size; - int first_n_packed = first_n / pack_factor; - - int4* sh_ptr = sh + stage_size * pipe; - - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - - cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + - first_n_packed + (n_id * 4)]))); - } - - cp_async_fence(); - }; - - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - return; - } - - int warp_id = threadIdx.x / 32; - int th_id = threadIdx.x % 32; - - if (warp_id >= 4) { - return; - } - - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; - - constexpr int tc_offsets[4] = {0, 1, 8, 9}; - - int cur_n = warp_id * 16 + tc_col; - int cur_n_packed = cur_n / pack_factor; - int cur_n_pos = cur_n % pack_factor; - - constexpr int sh_stride = tile_n_ints; - constexpr uint32_t mask = (1 << num_bits) - 1; - - int4* sh_stage_ptr = sh + stage_size * pipe; - uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - - // Undo interleaving - int cur_n_pos_unpacked; - if constexpr (num_bits == 4) { - constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } else { - constexpr int undo_pack[4] = {0, 2, 1, 3}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } - - uint32_t vals[8]; - #pragma unroll - for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; - - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + - sh_stride * cur_elem]; - - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; - } - - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - - uint32_t res = 0; - #pragma unroll - for (int i = 0; i < 8; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } - - out_ptr[out_offset + th_id * 4 + warp_id] = res; - - } else { - constexpr int pack_idx[4] = {0, 2, 1, 3}; - - uint32_t res1 = 0; - uint32_t res2 = 0; - #pragma unroll - for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } - - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } - }; - - auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } - - wait_for_stage(); - }; - #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { - int n_tile_id = 0; - - start_pipes(k_tile_id, n_tile_id); - - while (n_tile_id < n_tiles) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, - n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); - } - n_tile_id += repack_stages; - } - } -} - -} // namespace marlin - - #define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - marlin::awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::awq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, out_ptr, size_k, size_n); \ - } - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits) { - // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", marlin::tile_k_size); - TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", marlin::tile_n_size); - - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int const pack_factor = 32 / num_bits; - - // Verify B - TORCH_CHECK(b_q_weight.size(0) == size_k, - "b_q_weight.size(0) = ", b_q_weight.size(0), - " is not size_k = ", size_k); - TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1), - "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), - ", size_n = ", size_n, ", pack_factor = ", pack_factor); - - // Verify device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, - options); - - // Get ptrs - uint32_t const* b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - int dev = b_q_weight.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (false) { - } - CALL_IF(4) - CALL_IF(8) - else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); - } - - return out; -} - -#endif diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp deleted file mode 100644 index 8dcd2ff9..00000000 --- a/server/marlin/marlin_kernels/ext.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -#include "ext.hh" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("awq_marlin_repack", &awq_marlin_repack, - "Repack AWQ parameters for Marlin"); - m.def("gptq_marlin_gemm", &gptq_marlin_gemm, - "Marlin gemm with GPTQ compatibility"); - m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm"); - m.def("gptq_marlin_repack", &gptq_marlin_repack, - "Repack GPTQ parameters for Marlin"); - m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); - // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. - m.def("fp8_marlin_gemm", &fp8_marlin_gemm); -} diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh deleted file mode 100644 index ec76b5f4..00000000 --- a/server/marlin/marlin_kernels/ext.hh +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else - -torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits); - -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &b_zeros, - torch::Tensor &g_idx, torch::Tensor &perm, - torch::Tensor &workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp); - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k); - -#endif diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu deleted file mode 100644 index f6458ca6..00000000 --- a/server/marlin/marlin_kernels/fp8_marlin.cu +++ /dev/null @@ -1,1305 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "marlin.cuh" -#include "marlin_dtypes.cuh" - -using namespace marlin; - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace fp8_marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to BF16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to bfloat162 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - int slice_k_start = tb_k * slice_row; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We scale a `half2` tile in row-major layout for column-wise quantization. - int s_sh_rd = - 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - - thread_block_reduce(); - - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ - locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, - int group_size) { - int tb_n = th_config.thread_n; - - // Get max scale groups per thread-block - // Fixed for channelwise - int tb_groups = 1; - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, - prob_k, num_bits, group_size); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, - group_size, max_shared_mem); - } - - TORCH_CHECK( - exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, - ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = -1; - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - // Define kernel configurations - if (false) { - } - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - // Channelwise only for FP8 - TORCH_CHECK(b_scales.size(0) == 1) - num_groups = b_scales.size(0); - - // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), size_m, - size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, - dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else { - TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu deleted file mode 100644 index 122c5c16..00000000 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ /dev/null @@ -1,2195 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "marlin.cuh" -#include "marlin_dtypes.cuh" - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) {} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -template -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_4bit(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Zero-point dequantizers - -template -__device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_4bit_zp( - int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_4bit_zp(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template -__device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit_zp( - int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit_zp(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388608.f; - fp32_intermediates[1] -= 8388608.f; - fp32_intermediates[2] -= 8388608.f; - fp32_intermediates[3] -= 8388608.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; - - scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / default_threads; - int rest = size_k % default_threads; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += default_threads; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - div_ceil(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - if constexpr (!has_zp) { - return; - } - - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - if constexpr (num_bits == 4) { - int zp_quant = frag_qzp[k % 2][0]; - int zp_quant_shift = zp_quant >> 8; - frag_zp_0 = dequant_4bit_zp(zp_quant); - frag_zp_1 = dequant_4bit_zp(zp_quant_shift); - - } else { - int zp_quant_0 = frag_qzp[k % 2][0]; - int zp_quant_1 = frag_qzp[k % 2][1]; - frag_zp_0 = dequant_8bit_zp(zp_quant_0); - frag_zp_1 = dequant_8bit_zp(zp_quant_1); - } - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - if constexpr (has_zp) { - frag_b0 = dequant_4bit_zp(b_quant); - frag_b1 = dequant_4bit_zp(b_quant_shift); - - } else { - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - } - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - if constexpr (has_zp) { - frag_b0 = dequant_8bit_zp(b_quant_0); - frag_b1 = dequant_8bit_zp(b_quant_1); - } else { - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - } - - // Apply zero-point to frag_b0 - if constexpr (has_zp) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { - res = __hmul2(res, s[0]); - } - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float( - reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float( - reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( - reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - - int tb_n = th_config.thread_n; - int tb_k = th_config.thread_k; - - // Get max scale groups per thread-block - int tb_groups; - if (group_size == -1) { - tb_groups = 1; - } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size - } else { - tb_groups = div_ceil(tb_k, group_size); - } - - if (cache_scales_chunk) { - int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; - - } else { - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; - } -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - - #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, - void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - const int4* zp_ptr = (const int4*)zp; - const int* g_idx_ptr = (const int*)g_idx; - const int* perm_ptr = (const int*)perm; - int4* a_tmp_ptr = (int4*)a_tmp; - - int* locks = (int*)workspace; - - if (has_act_order) { - // Permute A columns - int block_rows = div_ceil(prob_m, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by having - // a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - if (false) { - } - GPTQ_CALL_IF(4, 16, 4, 256) - GPTQ_CALL_IF(4, 8, 8, 256) - GPTQ_CALL_IF(4, 8, 4, 128) - GPTQ_CALL_IF(4, 4, 8, 128) - GPTQ_CALL_IF(8, 16, 4, 256) - GPTQ_CALL_IF(8, 8, 8, 256) - GPTQ_CALL_IF(8, 8, 4, 128) - GPTQ_CALL_IF(8, 4, 8, 128) - - AWQ_CALL_IF(4, 16, 4, 256) - AWQ_CALL_IF(4, 8, 8, 256) - AWQ_CALL_IF(4, 8, 4, 128) - AWQ_CALL_IF(4, 4, 8, 128) - AWQ_CALL_IF(8, 16, 4, 256) - AWQ_CALL_IF(8, 8, 8, 256) - AWQ_CALL_IF(8, 8, 4, 128) - AWQ_CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, - ", ", prob_k, "]", ", has_act_order = ", has_act_order, - ", num_groups = ", num_groups, ", group_size = ", group_size, - ", thread_m_blocks = ", thread_m_blocks, - ", thread_n_blocks = ", thread_n_blocks, - ", thread_k_blocks = ", thread_k_blocks, - ", num_bits = ", num_bits); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); - TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); - - TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); - TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Verify g_idx and perm - TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || - (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = ", g_idx.size(0), - " and perm.size(0) = ", perm.size(0), - ", where size_k = ", size_k); - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(0) != 0; - - int rank = b_scales.sizes().size(); - TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - num_groups = b_scales.size(0); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - // Verify b_zeros - if (has_zp) { - int rank = b_zeros.sizes().size(); - TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); - TORCH_CHECK(b_zeros.size(0) == num_groups, - "b_zeros dim 0 = ", b_zeros.size(0), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_scales.size(1), - " is not size_n / pack_factor = ", size_n / pack_factor); - } - - // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin_repack.cu b/server/marlin/marlin_kernels/gptq_marlin_repack.cu deleted file mode 100644 index c71b1bf5..00000000 --- a/server/marlin/marlin_kernels/gptq_marlin_repack.cu +++ /dev/null @@ -1,344 +0,0 @@ -#include "marlin.cuh" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -namespace marlin { - -template -__global__ void gptq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace marlin - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -namespace marlin { - -template -__global__ void gptq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) { - constexpr int pack_factor = 32 / num_bits; - - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; - int block_k_tiles = div_ceil(k_tiles, gridDim.x); - - int start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) { - return; - } - - int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - extern __shared__ int4 sh[]; - - constexpr int perm_size = tile_k_size / 4; - - int4* sh_perm_ptr = sh; - int4* sh_pipe_ptr = sh_perm_ptr; - if constexpr (has_perm) { - sh_pipe_ptr += perm_size; - } - - constexpr int tile_ints = tile_k_size / pack_factor; - - constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; - constexpr int stage_size = stage_k_threads * stage_n_threads; - - auto load_perm_to_shared = [&](int k_tile_id) { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; - - int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); - - if (threadIdx.x < perm_size) { - sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; - } - __syncthreads(); - }; - - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - cp_async_fence(); - return; - } - - int first_n = n_tile_id * tile_n_size; - - int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; - - if constexpr (has_perm) { - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - uint32_t const* sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); - - int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor; - - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( - b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); - } - - } else { - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor; - - cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + - first_n + (n_id * 4)]))); - } - } - - cp_async_fence(); - }; - - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - return; - } - - int warp_id = threadIdx.x / 32; - int th_id = threadIdx.x % 32; - - if (warp_id >= 4) { - return; - } - - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; - - constexpr int tc_offsets[4] = {0, 1, 8, 9}; - - int cur_n = warp_id * 16 + tc_col; - - constexpr int sh_stride = 64; - constexpr uint32_t mask = (1 << num_bits) - 1; - - int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - - uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - - uint32_t vals[8]; - - if constexpr (has_perm) { - for (int i = 0; i < 4; i++) { - int k_idx = tc_row + tc_offsets[i]; - - uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor; - - uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; - - uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; - - vals[i] = b1_cur_val; - vals[4 + i] = b2_cur_val; - } - - } else { - uint32_t b1_vals[tile_ints]; - uint32_t b2_vals[tile_ints]; - - #pragma unroll - for (int i = 0; i < tile_ints; i++) { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; - } - - #pragma unroll - for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; - int cur_int = cur_elem / pack_factor; - int cur_pos = cur_elem % pack_factor; - - vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; - } - } - - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - - uint32_t res = 0; - #pragma unroll - for (int i = 0; i < 8; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } - - out_ptr[out_offset + th_id * 4 + warp_id] = res; - - } else { - constexpr int pack_idx[4] = {0, 2, 1, 3}; - - uint32_t res1 = 0; - uint32_t res2 = 0; - #pragma unroll - for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } - - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } - }; - - auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } - - wait_for_stage(); - }; - #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { - int n_tile_id = 0; - - if constexpr (has_perm) { - load_perm_to_shared(k_tile_id); - } - - start_pipes(k_tile_id, n_tile_id); - - while (n_tile_id < n_tiles) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, - n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); - } - n_tile_id += repack_stages; - } - } -} - -} // namespace marlin - - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - marlin::gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::gptq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", marlin::tile_k_size); - TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", marlin::tile_n_size); - - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int const pack_factor = 32 / num_bits; - - // Verify B - TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", pack_factor = ", pack_factor); - TORCH_CHECK(b_q_weight.size(1) == size_n, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not size_n = ", size_n); - - // Verify device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); - TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, - options); - - // Detect if there is act_order - bool has_perm = perm.size(0) != 0; - - // Get ptrs - uint32_t const* b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); - uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - int dev = b_q_weight.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (false) { - } - CALL_IF(4, false) - CALL_IF(4, true) - CALL_IF(8, false) - CALL_IF(8, true) - else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, - ", has_perm = ", has_perm); - } - - return out; -} - -#endif diff --git a/server/marlin/marlin_kernels/marlin.cuh b/server/marlin/marlin_kernels/marlin.cuh deleted file mode 100644 index 74ccbac5..00000000 --- a/server/marlin/marlin_kernels/marlin.cuh +++ /dev/null @@ -1,87 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace marlin { - -// Marlin params - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int default_threads = 256; - -static constexpr int pipe_stages = - 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -// Repack params -static constexpr int repack_stages = 8; - -static constexpr int repack_threads = 256; - -static constexpr int tile_k_size = tile_size; -static constexpr int tile_n_size = tile_k_size * 4; - -// Helpers -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -#endif - -} // namespace marlin diff --git a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu deleted file mode 100644 index 37339b84..00000000 --- a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1138 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x; - int b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} diff --git a/server/marlin/marlin_kernels/marlin_dtypes.cuh b/server/marlin/marlin_kernels/marlin_dtypes.cuh deleted file mode 100644 index be06c09b..00000000 --- a/server/marlin/marlin_kernels/marlin_dtypes.cuh +++ /dev/null @@ -1,79 +0,0 @@ - -#ifndef _data_types_cuh -#define _data_types_cuh -#include "marlin.cuh" -#include -#include - -namespace marlin { - -template -class ScalarType {}; - -template <> -class ScalarType { - public: - using scalar_t = half; - using scalar_t2 = half2; - - // Matrix fragments for tensor core instructions; their precise layout is - // documented here: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - using FragZP = Vec; - - static __device__ float inline num2float(const half x) { - return __half2float(x); - } - - static __device__ half2 inline num2num2(const half x) { - return __half2half2(x); - } - - static __device__ half2 inline nums2num2(const half x1, const half x2) { - return __halves2half2(x1, x2); - } - - static __host__ __device__ half inline float2num(const float x) { - return __float2half(x); - } -}; - -template <> -class ScalarType { - public: - using scalar_t = nv_bfloat16; - using scalar_t2 = nv_bfloat162; - - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - using FragZP = Vec; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static __device__ float inline num2float(const nv_bfloat16 x) { - return __bfloat162float(x); - } - - static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { - return __bfloat162bfloat162(x); - } - - static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, - const nv_bfloat16 x2) { - return __halves2bfloat162(x1, x2); - } - - static __host__ __device__ nv_bfloat16 inline float2num(const float x) { - return __float2bfloat16(x); - } -#endif -}; - -} // namespace marlin - -#endif diff --git a/server/marlin/marlin_kernels/py.typed b/server/marlin/marlin_kernels/py.typed deleted file mode 100644 index e69de29b..00000000 diff --git a/server/marlin/marlin_kernels/sparse/common/base.h b/server/marlin/marlin_kernels/sparse/common/base.h deleted file mode 100644 index 16018d33..00000000 --- a/server/marlin/marlin_kernels/sparse/common/base.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace marlin_24 { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -template -struct ShapeBase { - static constexpr int M = M_, N = N_, K = K_; -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragM = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mem.h b/server/marlin/marlin_kernels/sparse/common/mem.h deleted file mode 100644 index 83e3578d..00000000 --- a/server/marlin/marlin_kernels/sparse/common/mem.h +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" - -namespace marlin_24 { -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void* smem_ptr, - const void* glob_ptr, - bool pred = true, - const bool zfill = false) { - const int BYTES = 16; - int src_in_bytes = (zfill ? 0 : BYTES); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); -} - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_m); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mma.h b/server/marlin/marlin_kernels/sparse/common/mma.h deleted file mode 100644 index b26505f7..00000000 --- a/server/marlin/marlin_kernels/sparse/common/mma.h +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" -#include - -namespace marlin_24 { - -// On CUDA earlier than 12.5, the ordered_metadata version of this instruction -// is not supported. On later versions of CUDA the version without ordered -// metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially -// | reduced performance on some future architectures -#if defined CUDA_VERSION && CUDA_VERSION >= 12050 - #define MMA_SP_INST \ - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#else - #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#endif - -// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, - const FragA& frag_b, FragC& frag_c, FragM& frag_m, - const int psel) { - const uint32_t* a0 = reinterpret_cast(&a_frag0); - const uint32_t* a1 = reinterpret_cast(&a_frag1); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* e = reinterpret_cast(&frag_m); - - float* c = reinterpret_cast(&frag_c); - if (psel == 0) { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } else { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, - float c3) { - uint2 r; - asm("{\n\t" - ".reg .f16 a, b, c, d; \n\t" - "cvt.rn.f16.f32 a, %2; \n\t" - "cvt.rn.f16.f32 b, %3; \n\t" - "cvt.rn.f16.f32 c, %4; \n\t" - "cvt.rn.f16.f32 d, %5; \n\t" - "mov.b32 %0, {a, b}; \n\t" - "mov.b32 %1, {c, d}; \n\t" - "}" - : "=r"(r.x), "=r"(r.y) - : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); - return r; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, - FragS& s0, float* c4, float* c5, float* c6, - float* c7, FragS& s1) { - *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); - *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); - *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); - *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); - - *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); - *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); - *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); - *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); -} - -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu b/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu deleted file mode 100644 index b5effc30..00000000 --- a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu +++ /dev/null @@ -1,1125 +0,0 @@ -/* - * Notice: This file was modified by Neuralmagic inc to include 8-bit support - * - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -#else - - #include "common/mem.h" - #include "common/mma.h" - -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_24 { - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int THREADS = 256; -static constexpr int STAGES = 4; - -static constexpr int min_thread_n = 128; - -static constexpr int tile_size = 16; -static constexpr int max_par = 64; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - // number of thread_k_blocks in k-dim - int k_tiles = prob_k / 32 / thread_k_blocks; - // number of thread_n_blocks in n-dim - int n_tiles = prob_n / 16 / thread_n_blocks; - // iters needed to cover all slices - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - // number of threadblock tiles in the current slice - int slice_iters; - // total number of active threadblocks in the current slice - int slice_count = 0; - // index of threadblock in current slice; numbered bottom to top - int slice_idx; - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 32 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads //RLC: 2 * #warps k-dim - constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - constexpr int pack_factor = 32 / num_bits; - - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 - constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp - int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; - int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); - constexpr int m_sh_wr_delta = threads / 2; - constexpr int m_sh_rd_delta = threads / 2; - constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; - constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + - (threadIdx.x % (m_sh_stride)); - m_gl_rd += (m_sh_stride)*slice_col; - m_gl_rd += m_gl_rd_delta_o * slice_row; - int m_sh_wr = threadIdx.x; - int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; - - int s_gl_rd; - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } else { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) { - a_sh_rd_trans[0][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - a_sh_rd_trans[1][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); - } - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4* meta_ptr[m_sh_iters]; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - int4* sh_m = sh_s + (stages * s_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks][2]; - I4 frag_b_quant[2][b_thread_vecs]; - FragM frag_m[2][2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - int4* sh_meta_stage = sh_m + m_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) { - if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); - meta_ptr[i] += m_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - ldsm4(frag_a[k % 2][i][0], - &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); - ldsm4(frag_a[k % 2][i][1], - &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); - } - - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - - // Load meta with ldsm4 - int4* sh_m_stage = sh_m + m_sh_stage * pipe; - ldsm4_m(frag_m[k % 2][0], - &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], - frag_m[k % 2][j / 2], j % 2); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; - int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) - int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int col = 2 * ((threadIdx.x % 32) % 4); - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2] += - __half2float( - reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); - } - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2]); - } - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: - - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - - int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) - - constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: - int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + - (threadIdx.x % (2 * 2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, - float c4, float c5, float c6, float c7, FragS& s1) { - uint2 res[2]; - res[0] = to_half4(c0, c1, c2, c3); - res[1] = to_half4(c4, c5, c6, c7); - half2* tmp = (half2*)&res; - // for per-column quantization we finally apply the scale here - if constexpr (group_blocks == -1 && num_bits == 4) { - tmp[0] = __hmul2(tmp[0], s0[0]); - tmp[1] = __hmul2(tmp[1], s0[1]); - tmp[2] = __hmul2(tmp[2], s1[0]); - tmp[3] = __hmul2(tmp[3], s1[1]); - } - ((int4*)sh)[idx] = *((int4*)&res[0]); - }; - - // RLC: only warp 0 and 1 baseline example - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - int wr = c_sh_wr; - write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], - frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], - frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], - frag_s[0][2]); - write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], - frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], - frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], - frag_c[i][3][0][3], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], - frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], - frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], - frag_c[i][3][1][2], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], - frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], - frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], - frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); - - c_sh_wr += 8 * c_sh_stride_2; - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - matmul(pipe); - wait_for_stage(); - - fetch_to_registers(pipe + 1, (pipe + 1) % stages); - - pipe++; - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - } - } - thread_block_reduce(); - - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], - &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], - &frag_c[i][0][0][2], &frag_c[i][1][0][2], - &frag_c[i][2][0][2], &frag_c[i][3][0][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], - &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], - &frag_c[i][0][0][3], &frag_c[i][1][0][3], - &frag_c[i][2][0][3], &frag_c[i][3][0][3], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], - &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], - &frag_c[i][0][1][2], &frag_c[i][1][1][2], - &frag_c[i][2][1][2], &frag_c[i][3][1][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], - &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], - &frag_c[i][0][1][3], &frag_c[i][1][1][3], - &frag_c[i][2][1][3], &frag_c[i][3][1][3], - frag_s[0][2]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#endif - -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ - } - -void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, - void* s, int prob_m, int prob_n, int prob_k, - void* workspace, int num_bits, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_m = -1, int sms = -1, int max_par = 16) { - int tot_n = prob_n; - int tot_n_blocks = ceildiv(tot_n, 16); - int pad = 16 * tot_n_blocks - tot_n; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - TORCH_CHECK(sms > 0); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (thread_k == -1 || thread_m == -1) { - if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important - // than better compute utilization - thread_k = 128; - thread_m = 128; - } else if (prob_n <= 256) { - thread_k = 64; - thread_m = 256; - } else { - thread_k = 32; - thread_m = 512; - } - } - - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction - int thread_m_blocks = thread_m / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, - " is not divisible by thread_m = ", thread_m); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, - " is not divisible by group_blocks = ", group_blocks); - } - - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - const int4* meta_ptr = (const int4*)meta; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - constexpr int max_m_blocks = 4; - - int* locks = (int*)workspace; - for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { - int thread_n_blocks = tot_n_blocks - i; - prob_n = tot_n - 16 * i; - int par = 1; - if (thread_n_blocks > max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); - if (par > max_par) par = max_par; - prob_n = (max_m_blocks * 16) * par; - i += max_m_blocks * (par - 1); - thread_n_blocks = max_m_blocks; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - - // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group - // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 16, 2, 2, 4) - CALL_IF_2_4(4, 16, 3, 2, -1) - CALL_IF_2_4(4, 16, 3, 2, 4) - CALL_IF_2_4(4, 16, 4, 2, -1) - CALL_IF_2_4(4, 16, 4, 2, 4) - - CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 32, 2, 1, 4) - CALL_IF_2_4(4, 32, 3, 1, -1) - CALL_IF_2_4(4, 32, 3, 1, 4) - CALL_IF_2_4(4, 32, 4, 1, -1) - CALL_IF_2_4(4, 32, 4, 1, 4) - - // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 16, 2, 2, 4) - CALL_IF_2_4(8, 16, 3, 2, -1) - CALL_IF_2_4(8, 16, 3, 2, 4) - CALL_IF_2_4(8, 16, 4, 2, -1) - CALL_IF_2_4(8, 16, 4, 2, 4) - - CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 32, 2, 1, 4) - CALL_IF_2_4(8, 32, 3, 1, -1) - CALL_IF_2_4(8, 32, 3, 1, 4) - CALL_IF_2_4(8, 32, 4, 1, -1) - CALL_IF_2_4(8, 32, 4, 1, 4) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; - } -} - -} // namespace marlin_24 - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_24::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_24::tile_size)); - TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_24::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_24::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_24::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify meta - TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, - "b_meta.size(0) = ", b_meta.size(0), - " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); - TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), - " is not size_n * 2 = ", size_n * 2); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify b_meta device and strides - TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); - TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - int thread_k = -1; - int thread_m = -1; - int sms = -1; - int max_par = marlin_24::max_par; - - int groupsize = -1; - if (b_scales.size(0) > 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 - } - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 64, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_24::min_thread_n)); - int min_workspace_size = - (size_n / marlin_24::min_thread_n) * marlin_24::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_24::marlin_cuda_2_4( - a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_m, sms, max_par); - - return c; -} diff --git a/server/marlin/setup.py b/server/marlin/setup.py deleted file mode 100644 index f92eef04..00000000 --- a/server/marlin/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -extra_compile_args = [] - -setup( - name="marlin_kernels", - ext_modules=[ - CUDAExtension( - name="marlin_kernels", - sources=[ - "marlin_kernels/awq_marlin_repack.cu", - "marlin_kernels/fp8_marlin.cu", - "marlin_kernels/gptq_marlin.cu", - "marlin_kernels/gptq_marlin_repack.cu", - "marlin_kernels/marlin_cuda_kernel.cu", - "marlin_kernels/sparse/marlin_24_cuda_kernel.cu", - "marlin_kernels/ext.cpp", - ], - extra_compile_args=extra_compile_args, - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/server/poetry.lock b/server/poetry.lock index c993a441..5072aa0b 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1139,6 +1139,74 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" + [[package]] name = "mpmath" version = "1.3.0" @@ -3507,6 +3575,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"] outlines = ["outlines"] peft = ["peft"] quantize = ["accelerate", "datasets", "texttable"] @@ -3515,4 +3584,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "c94bbdf8131750891fb3f7132066718534129d85a4c09126d8d01c2de6c72798" +content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1" diff --git a/server/pyproject.toml b/server/pyproject.toml index f8c3cb8e..da0d1526 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -40,10 +40,18 @@ py-cpuinfo = "^9.0.0" # Remove later, temporary workaround for outlines. numpy = "^1.26" +marlin-kernels = [ + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, +] + [tool.poetry.extras] torch = ["torch"] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +marlin = ["marlin-kernels"] peft = ["peft"] quantize = ["texttable", "datasets", "accelerate"] outlines = ["outlines"] diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 697634ea..80674b14 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module): A_flat.shape[1], self.is_full_k, self.qzeros.numel() > 0, + True, ) C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) From a574381cb489ee99644b7ebc33598ba8f87a651c Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 29 Jul 2024 10:07:25 -0400 Subject: [PATCH 173/340] fix: reject grammars without properties (#2309) --- router/src/validation.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/router/src/validation.rs b/router/src/validation.rs index 07ad14c9..1913a2ce 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -353,6 +353,14 @@ impl Validation { .compile(&json) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; + // The schema can be valid but lack properties. + // We need properties for the grammar to be successfully parsed in Python. + // Therefore, we must check and throw an error if properties are missing. + json.get("properties") + .ok_or(ValidationError::InvalidGrammar( + "Grammar must have a 'properties' field".to_string(), + ))?; + // Serialize json to string ValidGrammar::Json( serde_json::to_string(&json) From b1d1d265592ed7eca00ab86dc3ae3d6e1ef52c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Mon, 29 Jul 2024 16:09:25 +0200 Subject: [PATCH 174/340] patch-error-on-invalid-grammar (#2282) * quick fix * allow silent failure * explicit todo that this is only short term --- server/text_generation_server/utils/logits_process.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 6b915437..9abd886f 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -521,7 +521,13 @@ class GrammarLogitProcessor(LogitsProcessor): def _cached_compile_fsm(grammar_type, schema, tokenizer): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: - schema = build_regex_from_schema(schema) + try: + schema = build_regex_from_schema(schema) + # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer + except Exception as e: + logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}") + # allows everything + schema = "(.*?)" elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity fsm = RegexFSM(schema, tokenizer) From bafab73f761bab0cb69293e13d294c4bfd2052ad Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 29 Jul 2024 11:38:38 -0400 Subject: [PATCH 175/340] fix: adjust test snapshots and small refactors (#2323) * fix: adjust test snapshots and small refactors * fix: revert non snapshot changes --- .../test_bloom_560m/test_bloom_560m.json | 85 +++++----- .../test_bloom_560m_all_params.json | 55 ++++--- .../test_bloom_560m_sharded.json | 85 +++++----- ...t_flash_llama_completion_many_prompts.json | 18 +- .../test_flash_deepseek_v2_all_params.json | 54 +----- .../test_flash_deepseek_v2_load.json | 154 +++++++++--------- .../test_flash_gemma/test_flash_gemma.json | 34 ++-- .../test_flash_gemma_all_params.json | 20 +-- .../test_flash_gemma_gptq_all_params.json | 84 +++++----- .../test_flash_starcoder_default_params.json | 17 +- .../test_flash_starcoder2_default_params.json | 120 +++++++------- .../__snapshots__/test_mamba/test_mamba.json | 18 +- .../test_mamba/test_mamba_all_params.json | 6 +- 13 files changed, 359 insertions(+), 391 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json index 53a4ab85..b274992e 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json @@ -11,52 +11,52 @@ }, { "id": 49833, - "logprob": -10.5625, + "logprob": -10.5703125, "text": " dég" }, { "id": 21543, - "logprob": -0.14770508, + "logprob": -0.14746094, "text": "uster" }, { "id": 447, - "logprob": -1.9287109, + "logprob": -1.9277344, "text": " un" }, { "id": 46341, - "logprob": -15.4609375, + "logprob": -15.421875, "text": " ort" }, { "id": 35567, - "logprob": -7.5585938, + "logprob": -7.5820312, "text": "olan" }, { "id": 15, - "logprob": -1.4003906, + "logprob": -1.4013672, "text": "," }, { "id": 1669, - "logprob": -1.5673828, + "logprob": -1.5664062, "text": " il" }, { "id": 11580, - "logprob": -0.94628906, + "logprob": -0.94189453, "text": " faut" }, { "id": 3913, - "logprob": -3.703125, + "logprob": -3.6816406, "text": " tout" }, { "id": 39261, - "logprob": -1.5732422, + "logprob": -1.7753906, "text": " d'abord" } ], @@ -64,65 +64,66 @@ "tokens": [ { "id": 578, - "logprob": -1.6591797, + "logprob": -1.6318359, "special": false, "text": " le" }, { "id": 5608, - "logprob": -2.4492188, + "logprob": -2.4882812, "special": false, "text": " faire" }, { - "id": 159570, - "logprob": -6.6835938, + "id": 7735, + "logprob": -2.4355469, "special": false, - "text": " réch" + "text": " fond" }, { - "id": 810, + "id": 289, "logprob": 0.0, "special": false, - "text": "au" + "text": "re" }, { - "id": 12736, + "id": 693, + "logprob": -2.4472656, + "special": false, + "text": " à" + }, + { + "id": 366, + "logprob": -1.1972656, + "special": false, + "text": " la" + }, + { + "id": 48844, + "logprob": -1.7890625, + "special": false, + "text": " cass" + }, + { + "id": 1744, "logprob": 0.0, "special": false, - "text": "ffer" + "text": "ero" }, { - "id": 1742, - "logprob": -2.5175781, - "special": false, - "text": " au" - }, - { - "id": 6105, - "logprob": -2.0078125, - "special": false, - "text": " bain" - }, - { - "id": 88254, - "logprob": -0.12695312, - "special": false, - "text": "-mar" - }, - { - "id": 641, + "id": 327, "logprob": 0.0, "special": false, - "text": "ie" + "text": "le" }, { "id": 2940, - "logprob": -3.5175781, + "logprob": -1.9335938, "special": false, "text": " avec" } - ] + ], + "top_tokens": null }, - "generated_text": " le faire réchauffer au bain-marie avec" + "generated_text": " le faire fondre à la casserole avec" } diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json index ace73416..9422f27f 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json @@ -11,7 +11,7 @@ }, { "id": 1669, - "logprob": -5.4414062, + "logprob": -5.4453125, "text": " il" }, { @@ -21,12 +21,12 @@ }, { "id": 3913, - "logprob": -4.3554688, + "logprob": -4.3320312, "text": " tout" }, { "id": 39261, - "logprob": -2.9238281, + "logprob": -2.9160156, "text": " d'abord" } ], @@ -34,65 +34,66 @@ "tokens": [ { "id": 408, - "logprob": -0.07891846, + "logprob": -0.16687012, "special": false, "text": " que" }, { "id": 366, - "logprob": -1.2939453, + "logprob": -1.5517578, "special": false, "text": " la" }, { "id": 8769, - "logprob": -0.3708496, + "logprob": -0.16687012, "special": false, "text": " personne" }, { "id": 1479, - "logprob": -2.2871094, + "logprob": -2.1035156, "special": false, "text": " qui" }, { - "id": 2997, - "logprob": -0.8671875, + "id": 143926, + "logprob": -2.8671875, "special": false, - "text": " vous" + "text": " réalise" }, { - "id": 35977, - "logprob": -1.5097656, + "id": 578, + "logprob": 0.0, "special": false, - "text": " suit" + "text": " le" }, { - "id": 21558, - "logprob": -0.07891846, + "id": 8138, + "logprob": -0.66748047, "special": false, - "text": " ait" + "text": " projet" }, { - "id": 447, - "logprob": -0.12695312, + "id": 795, + "logprob": -1.6279297, "special": false, - "text": " un" + "text": " ne" }, { - "id": 78606, - "logprob": -2.21875, + "id": 9802, + "logprob": -0.47875977, "special": false, - "text": " profil" + "text": " soit" }, { - "id": 3899, - "logprob": -1.3535156, + "id": 1230, + "logprob": 0.0, "special": false, - "text": " bien" + "text": " pas" } - ] + ], + "top_tokens": null }, - "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien" + "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui réalise le projet ne soit pas" } diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json index dd8936af..b17c889e 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json @@ -11,52 +11,52 @@ }, { "id": 49833, - "logprob": -10.5390625, + "logprob": -10.546875, "text": " dég" }, { "id": 21543, - "logprob": -0.14758301, + "logprob": -0.14819336, "text": "uster" }, { "id": 447, - "logprob": -1.9296875, + "logprob": -1.9257812, "text": " un" }, { "id": 46341, - "logprob": -15.4453125, + "logprob": -15.4296875, "text": " ort" }, { "id": 35567, - "logprob": -7.59375, + "logprob": -7.5625, "text": "olan" }, { "id": 15, - "logprob": -1.3994141, + "logprob": -1.4199219, "text": "," }, { "id": 1669, - "logprob": -1.578125, + "logprob": -1.5634766, "text": " il" }, { "id": 11580, - "logprob": -0.9453125, + "logprob": -0.9458008, "text": " faut" }, { "id": 3913, - "logprob": -3.7011719, + "logprob": -3.6816406, "text": " tout" }, { "id": 39261, - "logprob": -1.5732422, + "logprob": -1.7753906, "text": " d'abord" } ], @@ -64,65 +64,66 @@ "tokens": [ { "id": 578, - "logprob": -1.6474609, + "logprob": -1.828125, "special": false, "text": " le" }, { "id": 5608, - "logprob": -2.5097656, + "logprob": -2.5546875, "special": false, "text": " faire" }, { - "id": 159570, - "logprob": -6.65625, + "id": 7735, + "logprob": -2.4277344, "special": false, - "text": " réch" + "text": " fond" }, { - "id": 810, + "id": 289, "logprob": 0.0, "special": false, - "text": "au" + "text": "re" }, { - "id": 12736, + "id": 693, + "logprob": -2.4472656, + "special": false, + "text": " à" + }, + { + "id": 366, + "logprob": -1.1494141, + "special": false, + "text": " la" + }, + { + "id": 48844, + "logprob": -1.7939453, + "special": false, + "text": " cass" + }, + { + "id": 1744, "logprob": 0.0, "special": false, - "text": "ffer" + "text": "ero" }, { - "id": 1742, - "logprob": -2.5859375, - "special": false, - "text": " au" - }, - { - "id": 6105, - "logprob": -2.03125, - "special": false, - "text": " bain" - }, - { - "id": 88254, - "logprob": -0.12695312, - "special": false, - "text": "-mar" - }, - { - "id": 641, + "id": 327, "logprob": 0.0, "special": false, - "text": "ie" + "text": "le" }, { "id": 2940, - "logprob": -3.5175781, + "logprob": -1.9013672, "special": false, "text": " avec" } - ] + ], + "top_tokens": null }, - "generated_text": " le faire réchauffer au bain-marie avec" + "generated_text": " le faire fondre à la casserole avec" } diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json index 99c33cf7..9f3faffc 100644 --- a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json @@ -1,11 +1,17 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 1, "logprobs": null, "text": " PR for more information?" }, + { + "finish_reason": "length", + "index": 3, + "logprobs": null, + "text": "hd20220811-" + }, { "finish_reason": "length", "index": 0, @@ -17,19 +23,13 @@ "index": 2, "logprobs": null, "text": " severely flawed and often has a substandard" - }, - { - "finish_reason": "length", - "index": 3, - "logprobs": null, - "text": "hd20220811-" } ], - "created": 1713284455, + "created": 1722014725, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native", + "system_fingerprint": "2.2.1-dev0-native", "usage": { "completion_tokens": 36, "prompt_tokens": 8, diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json index e84135cf..6b45cf6b 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 4, "prefill": [ { "id": 100000, @@ -16,7 +16,7 @@ }, { "id": 3102, - "logprob": -11.1875, + "logprob": -11.25, "text": " request" } ], @@ -30,7 +30,7 @@ }, { "id": 10081, - "logprob": -0.36914062, + "logprob": -0.41210938, "special": false, "text": " successfully" }, @@ -41,49 +41,13 @@ "text": "." }, { - "id": 185, - "logprob": 0.0, - "special": false, - "text": "\n" - }, - { - "id": 1380, - "logprob": -0.38671875, - "special": false, - "text": "We" - }, - { - "id": 543, - "logprob": -0.12695312, - "special": false, - "text": " will" - }, - { - "id": 752, - "logprob": -0.20117188, - "special": false, - "text": " get" - }, - { - "id": 279, - "logprob": 0.0, - "special": false, - "text": " in" - }, - { - "id": 5402, - "logprob": 0.0, - "special": false, - "text": " touch" - }, - { - "id": 366, - "logprob": 0.0, - "special": false, - "text": " with" + "id": 100001, + "logprob": -0.16015625, + "special": true, + "text": "<|end▁of▁sentence|>" } ], "top_tokens": null }, - "generated_text": "Test request sent successfully.\nWe will get in touch with" + "generated_text": "Test request sent successfully." } diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json index a4b9784a..e365829a 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -17,7 +17,7 @@ }, { "id": 3102, - "logprob": -11.1875, + "logprob": -11.25, "text": " request" } ], @@ -25,68 +25,68 @@ "tokens": [ { "id": 185, - "logprob": -1.5546875, + "logprob": -1.546875, "special": false, "text": "\n" }, { "id": 549, - "logprob": -2.8125, + "logprob": -2.859375, "special": false, "text": "The" }, { "id": 1727, - "logprob": -2.375, + "logprob": -2.359375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.890625, + "logprob": -0.83203125, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.1484375, + "logprob": -1.125, "special": false, "text": " is" }, { "id": 245, - "logprob": -1.5390625, + "logprob": -1.5703125, "special": false, "text": " a" }, { - "id": 3102, - "logprob": -2.609375, + "id": 3412, + "logprob": -2.578125, "special": false, - "text": " request" + "text": " document" }, { - "id": 327, - "logprob": -0.75, + "id": 344, + "logprob": -1.125, "special": false, - "text": " for" + "text": " that" }, { - "id": 245, - "logprob": -1.1171875, + "id": 317, + "logprob": -1.6953125, "special": false, - "text": " a" + "text": " is" }, { - "id": 1727, - "logprob": -0.90625, + "id": 1222, + "logprob": -1.75, "special": false, - "text": " test" + "text": " used" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a request for a test" + "generated_text": "\nThe test request is a document that is used" }, { "details": { @@ -114,68 +114,68 @@ "tokens": [ { "id": 185, - "logprob": -1.5546875, + "logprob": -1.546875, "special": false, "text": "\n" }, { "id": 549, - "logprob": -2.8125, + "logprob": -2.859375, "special": false, "text": "The" }, { "id": 1727, - "logprob": -2.375, + "logprob": -2.359375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.890625, + "logprob": -0.83203125, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.1484375, + "logprob": -1.125, "special": false, "text": " is" }, { "id": 245, - "logprob": -1.5390625, + "logprob": -1.5703125, "special": false, "text": " a" }, { - "id": 3102, - "logprob": -2.609375, + "id": 3412, + "logprob": -2.578125, "special": false, - "text": " request" + "text": " document" }, { - "id": 327, - "logprob": -0.75, + "id": 344, + "logprob": -1.125, "special": false, - "text": " for" + "text": " that" }, { - "id": 245, - "logprob": -1.1171875, + "id": 317, + "logprob": -1.6953125, "special": false, - "text": " a" + "text": " is" }, { - "id": 1727, - "logprob": -0.90625, + "id": 1222, + "logprob": -1.75, "special": false, - "text": " test" + "text": " used" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a request for a test" + "generated_text": "\nThe test request is a document that is used" }, { "details": { @@ -203,68 +203,68 @@ "tokens": [ { "id": 185, - "logprob": -1.5546875, + "logprob": -1.546875, "special": false, "text": "\n" }, { "id": 549, - "logprob": -2.8125, + "logprob": -2.859375, "special": false, "text": "The" }, { "id": 1727, - "logprob": -2.375, + "logprob": -2.359375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.890625, + "logprob": -0.83203125, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.1484375, + "logprob": -1.125, "special": false, "text": " is" }, { "id": 245, - "logprob": -1.5390625, + "logprob": -1.5703125, "special": false, "text": " a" }, { - "id": 3102, - "logprob": -2.609375, + "id": 3412, + "logprob": -2.578125, "special": false, - "text": " request" + "text": " document" }, { - "id": 327, - "logprob": -0.75, + "id": 344, + "logprob": -1.125, "special": false, - "text": " for" + "text": " that" }, { - "id": 245, - "logprob": -1.1171875, + "id": 317, + "logprob": -1.6953125, "special": false, - "text": " a" + "text": " is" }, { - "id": 1727, - "logprob": -0.90625, + "id": 1222, + "logprob": -1.75, "special": false, - "text": " test" + "text": " used" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a request for a test" + "generated_text": "\nThe test request is a document that is used" }, { "details": { @@ -292,67 +292,67 @@ "tokens": [ { "id": 185, - "logprob": -1.5546875, + "logprob": -1.546875, "special": false, "text": "\n" }, { "id": 549, - "logprob": -2.8125, + "logprob": -2.859375, "special": false, "text": "The" }, { "id": 1727, - "logprob": -2.375, + "logprob": -2.359375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.890625, + "logprob": -0.83203125, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.1484375, + "logprob": -1.125, "special": false, "text": " is" }, { "id": 245, - "logprob": -1.5390625, + "logprob": -1.5703125, "special": false, "text": " a" }, { - "id": 3102, - "logprob": -2.609375, + "id": 3412, + "logprob": -2.578125, "special": false, - "text": " request" + "text": " document" }, { - "id": 327, - "logprob": -0.75, + "id": 344, + "logprob": -1.125, "special": false, - "text": " for" + "text": " that" }, { - "id": 245, - "logprob": -1.1171875, + "id": 317, + "logprob": -1.6953125, "special": false, - "text": " a" + "text": " is" }, { - "id": 1727, - "logprob": -0.90625, + "id": 1222, + "logprob": -1.75, "special": false, - "text": " test" + "text": " used" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a request for a test" + "generated_text": "\nThe test request is a document that is used" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json index 80f0d053..8829f9fe 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -11,12 +11,12 @@ }, { "id": 2015, - "logprob": -10.0, + "logprob": -10.0625, "text": "Test" }, { "id": 3853, - "logprob": -10.875, + "logprob": -11.0, "text": " request" } ], @@ -24,7 +24,7 @@ "tokens": [ { "id": 1736, - "logprob": -2.09375, + "logprob": -2.03125, "special": false, "text": " form" }, @@ -42,48 +42,48 @@ }, { "id": 2121, - "logprob": -1.8203125, + "logprob": -1.8125, "special": false, "text": " test" }, { "id": 3853, - "logprob": -0.23242188, + "logprob": -0.24121094, "special": false, "text": " request" }, { "id": 1736, - "logprob": -0.08544922, + "logprob": -0.100097656, "special": false, "text": " form" }, { "id": 603, - "logprob": -0.9375, + "logprob": -0.9453125, "special": false, "text": " is" }, { - "id": 1671, - "logprob": -1.671875, + "id": 476, + "logprob": -1.703125, "special": false, - "text": " used" + "text": " a" }, { - "id": 577, - "logprob": -0.40429688, + "id": 4551, + "logprob": -2.453125, "special": false, - "text": " to" + "text": " document" }, { - "id": 3853, - "logprob": -1.1875, + "id": 674, + "logprob": -0.796875, "special": false, - "text": " request" + "text": " that" } ], "top_tokens": null }, - "generated_text": " form\n\nThe test request form is used to request" + "generated_text": " form\n\nThe test request form is a document that" } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json index 8253dc96..0b840bfd 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json @@ -11,12 +11,12 @@ }, { "id": 2015, - "logprob": -10.0, + "logprob": -10.0625, "text": "Test" }, { "id": 3853, - "logprob": -10.875, + "logprob": -11.0, "text": " request" } ], @@ -24,7 +24,7 @@ "tokens": [ { "id": 7539, - "logprob": -0.73046875, + "logprob": -0.609375, "special": false, "text": " forms" }, @@ -36,7 +36,7 @@ }, { "id": 671, - "logprob": -1.703125, + "logprob": -1.5546875, "special": false, "text": " an" }, @@ -66,24 +66,24 @@ }, { "id": 11859, - "logprob": -1.6953125, + "logprob": -1.953125, "special": false, "text": " lab" }, { "id": 2185, - "logprob": -1.3125, + "logprob": -1.7734375, "special": false, "text": " process" }, { - "id": 578, - "logprob": -1.5, + "id": 235265, + "logprob": 0.0, "special": false, - "text": " and" + "text": "." } ], "top_tokens": null }, - "generated_text": "Test request forms are an essential part of the lab process and" + "generated_text": "Test request forms are an essential part of the lab process." } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json index 7a168b2e..bc80a0f9 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -11,12 +11,12 @@ }, { "id": 2015, - "logprob": -9.65625, + "logprob": -9.640625, "text": "Test" }, { "id": 3853, - "logprob": -10.3671875, + "logprob": -10.375, "text": " request" } ], @@ -24,66 +24,66 @@ "tokens": [ { "id": 604, - "logprob": -0.36938477, + "logprob": -0.2824707, "special": false, "text": " for" }, { - "id": 235248, - "logprob": -1.8046875, + "id": 573, + "logprob": -0.19030762, "special": false, - "text": " " + "text": " the" }, { - "id": 235274, - "logprob": -0.46240234, + "id": 16819, + "logprob": -1.4892578, "special": false, - "text": "1" + "text": " detection" }, { - "id": 235284, - "logprob": -1.7460938, + "id": 576, + "logprob": -0.7011719, "special": false, - "text": "2" + "text": " of" }, { - "id": 235265, - "logprob": -1.9443359, + "id": 573, + "logprob": -2.0195312, "special": false, - "text": "." + "text": " the" }, { - "id": 235284, - "logprob": -1.4550781, - "special": false, - "text": "2" - }, - { - "id": 235308, - "logprob": -1.0205078, - "special": false, - "text": "5" - }, - { - "id": 235290, - "logprob": -1.0283203, - "special": false, - "text": "-" - }, - { - "id": 235274, - "logprob": -1.2783203, - "special": false, - "text": "1" - }, - { - "id": 235284, + "id": 8566, "logprob": 0.0, "special": false, - "text": "2" + "text": " presence" + }, + { + "id": 689, + "logprob": -0.16491699, + "special": false, + "text": " or" + }, + { + "id": 14862, + "logprob": 0.0, + "special": false, + "text": " absence" + }, + { + "id": 576, + "logprob": -0.9946289, + "special": false, + "text": " of" + }, + { + "id": 671, + "logprob": -0.5263672, + "special": false, + "text": " an" } ], "top_tokens": null }, - "generated_text": "Test request for 12.25-12" + "generated_text": "Test request for the detection of the presence or absence of an" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index 89e02c07..164e3cf2 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -11,17 +11,17 @@ }, { "id": 1459, - "logprob": -5.6328125, + "logprob": -5.625, "text": " print" }, { "id": 81, - "logprob": -1.6035156, + "logprob": -1.6064453, "text": "_" }, { "id": 7656, - "logprob": -5.9882812, + "logprob": -5.9921875, "text": "hello" } ], @@ -29,7 +29,7 @@ "tokens": [ { "id": 2262, - "logprob": -0.042999268, + "logprob": -0.045715332, "special": false, "text": "():" }, @@ -59,7 +59,7 @@ }, { "id": 10896, - "logprob": -0.38549805, + "logprob": -0.3659668, "special": false, "text": " World" }, @@ -113,7 +113,7 @@ }, { "id": 426, - "logprob": 0.0, + "logprob": -0.051635742, "special": false, "text": "name" }, @@ -323,7 +323,7 @@ }, { "id": 313, - "logprob": -0.6328125, + "logprob": -0.6933594, "special": false, "text": " \"" }, @@ -387,7 +387,8 @@ "special": false, "text": " print" } - ] + ], + "top_tokens": null }, "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json index 38117272..d882b82a 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -11,12 +11,12 @@ }, { "id": 1489, - "logprob": -5.2617188, + "logprob": -5.265625, "text": " print" }, { "id": 100, - "logprob": -0.38476562, + "logprob": -0.38549805, "text": "_" }, { @@ -29,7 +29,7 @@ "tokens": [ { "id": 2284, - "logprob": -0.296875, + "logprob": -0.31323242, "special": false, "text": "():" }, @@ -53,19 +53,19 @@ }, { "id": 8302, - "logprob": -0.28125, + "logprob": -0.26611328, "special": false, "text": "Hello" }, { "id": 10914, - "logprob": -0.79248047, + "logprob": -0.7817383, "special": false, "text": " World" }, { "id": 16013, - "logprob": -0.61816406, + "logprob": -0.6328125, "special": false, "text": "!\")" }, @@ -83,7 +83,7 @@ }, { "id": 610, - "logprob": -0.4091797, + "logprob": -0.4086914, "special": false, "text": "def" }, @@ -113,7 +113,7 @@ }, { "id": 444, - "logprob": -0.21655273, + "logprob": -0.21826172, "special": false, "text": "name" }, @@ -160,16 +160,28 @@ "text": "Hello" }, { - "id": 332, - "logprob": -0.034698486, + "id": 925, + "logprob": -3.3476562, "special": false, - "text": " \"" + "text": " %" }, { - "id": 494, + "id": 120, "logprob": 0.0, "special": false, - "text": " +" + "text": "s" + }, + { + "id": 11571, + "logprob": -0.10021973, + "special": false, + "text": "!\"" + }, + { + "id": 925, + "logprob": 0.0, + "special": false, + "text": " %" }, { "id": 655, @@ -178,22 +190,10 @@ "text": " name" }, { - "id": 494, - "logprob": -0.20141602, - "special": false, - "text": " +" - }, - { - "id": 332, + "id": 46, "logprob": 0.0, "special": false, - "text": " \"" - }, - { - "id": 16013, - "logprob": 0.0, - "special": false, - "text": "!\")" + "text": ")" }, { "id": 222, @@ -251,7 +251,7 @@ }, { "id": 400, - "logprob": 0.0, + "logprob": -0.074279785, "special": false, "text": "age" }, @@ -310,34 +310,22 @@ "text": "Hello" }, { - "id": 332, + "id": 925, "logprob": 0.0, "special": false, - "text": " \"" + "text": " %" }, { - "id": 494, + "id": 120, "logprob": 0.0, "special": false, - "text": " +" + "text": "s" }, { - "id": 655, - "logprob": 0.0, + "id": 49, + "logprob": -0.07891846, "special": false, - "text": " name" - }, - { - "id": 494, - "logprob": 0.0, - "special": false, - "text": " +" - }, - { - "id": 3021, - "logprob": -0.5761719, - "special": false, - "text": " \"," + "text": "," }, { "id": 863, @@ -352,43 +340,55 @@ "text": " are" }, { - "id": 332, + "id": 925, "logprob": 0.0, "special": false, - "text": " \"" + "text": " %" }, { - "id": 494, + "id": 105, "logprob": 0.0, "special": false, - "text": " +" + "text": "d" }, { - "id": 615, + "id": 11339, "logprob": 0.0, "special": false, - "text": " str" + "text": " years" }, { - "id": 45, + "id": 3627, "logprob": 0.0, "special": false, - "text": "(" + "text": " old" }, { - "id": 400, + "id": 11571, "logprob": 0.0, "special": false, - "text": "age" + "text": "!\"" }, { - "id": 46, + "id": 925, "logprob": 0.0, "special": false, - "text": ")" + "text": " %" + }, + { + "id": 327, + "logprob": 0.0, + "special": false, + "text": " (" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" } ], "top_tokens": null }, - "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name" } diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json index eaba5078..9079b3bd 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -14,55 +14,55 @@ }, { "id": 187, - "logprob": -0.26953125, + "logprob": -0.35742188, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.1953125, + "logprob": -1.1015625, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.53515625, + "logprob": -0.5234375, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.625, + "logprob": -0.55078125, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.6796875, + "logprob": -0.6640625, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0, + "logprob": -2.0625, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3125, + "logprob": -2.375, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0028533936, + "logprob": -0.0029144287, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.265625, + "logprob": -1.2734375, "special": false, "text": " machine" } diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json index 85e9a9e0..ef88926c 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -52,7 +52,7 @@ }, { "id": 9830, - "logprob": -1.65625, + "logprob": -2.25, "special": false, "text": " colors" }, @@ -64,13 +64,13 @@ }, { "id": 329, - "logprob": -2.4375, + "logprob": -2.171875, "special": false, "text": " A" }, { "id": 1180, - "logprob": -1.953125, + "logprob": -2.046875, "special": false, "text": " number" }, From 247a29f77cb7eb4fc33c654ee197064e2bdc5004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 30 Jul 2024 15:16:20 +0200 Subject: [PATCH 176/340] server quantize: store quantizer config in standard format (#2299) - Create `quantization_config` option in the model config. - Don't store the quantizer config in tensors anymore. --- .../layers/gptq/quantize.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 42c986b0..b0086ea0 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -17,7 +17,7 @@ from loguru import logger from typing import Optional from text_generation_server.layers.gptq.utils import torch_snr_error -from text_generation_server.utils.weights import DefaultWeightsLoader +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight DEV = torch.device("cuda:0") @@ -897,7 +897,7 @@ def quantize( dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, - weights_loader=DefaultWeightsLoader(), + weights_loader=DefaultWeightsLoader(UnquantizedWeight), ) hooks = [] for name, module in model.named_modules(): @@ -960,9 +960,6 @@ def quantize( state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} - state_dict["gptq_bits"] = torch.LongTensor([bits]) - state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) - state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( @@ -994,6 +991,15 @@ def quantize( f"index located at {save_index_file}." ) config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + config.quantization_config = { + "bits": bits, + "group_size": groupsize, + "damp_percent": percdamp, + "desc_act": act_order, + "static_groups": False, + "sym": sym, + "quant_method": "gptq", + } config.save_pretrained(output_dir) logger.info("Saved config") logger.info("Saving tokenizer") From 120d5773e8ce58b5adcaeb526014bc9046092951 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 31 Jul 2024 10:33:10 +0200 Subject: [PATCH 177/340] Rebase TRT-llm (#2331) * wip wip refacto refacto Initial setup for CXX binding to TRTLLM Working FFI call for TGI and TRTLLM backend Remove unused parameters annd force tokenizer name to be set Overall build TRTLLM and deps through CMake build system Enable end to end CMake build First version loading engines and making it ready for inference Remembering to check how we can detect support for chunked context Move to latest TensorRT-LLM version Specify which default log level to use depending on CMake build type make leader executor mode working unconditionally call InitializeBackend on the FFI layer bind to CUDA::nvml to retrieve compute capabilities at runtime updated logic and comment to detect cuda compute capabilities implement the Stream method to send new tokens through a callback use spdlog release 1.14.1 moving forward update trtllm to latest version a96cccafcf6365c128f004f779160951f8c0801c correctly tell cmake to build dependent tensorrt-llm required libraries create cmake install target to put everything relevant in installation folder add auth_token CLI argument to provide hf hub authentification token allow converting huggingface::tokenizers error to TensorRtLlmBackendError use correct include for spdlog include guard to build example in cmakelists working setup of the ffi layer remove fmt import use external fmt lib end to end ffi flow working make sure to track include/ffi.h to trigger rebuild from cargo impl the rust backend which currently cannot move the actual computation in background thread expose shutdown function at ffi layer impl RwLock scenario for TensorRtLllmBackend oops missing c++ backend definitions compute the number of maximum new tokens for each request independently make sure the context is not dropped in the middle of the async decoding. remove unnecessary log add all the necessary plumbery to return the generated content update invalid doc in cpp file correctly forward back the log probabilities remove unneeded scope variable for now refactor Stream impl for Generation to factorise code expose the internal missing start/queue timestamp forward tgi parameters rep/freq penalty add some more validation about grammar not supported define a shared struct to hold the result of a decoding step expose information about potential error happening while decoding remove logging add logging in case of decoding error make sure executor_worker is provided add initial Dockerfile for TRTLLM backend add some more information in CMakeLists.txt to correctly install executorWorker add some more information in CMakeLists.txt to correctly find and install nvrtc wrapper simplify prebuilt trtllm libraries name definition do the same name definition stuff for tensorrt_llm_executor_static leverage pkg-config to probe libraries paths and reuse new install structure from cmake fix bad copy/past missing nvinfer linkage direction align all the linker search dependency add missing pkgconfig folder for MPI in Dockerfile correctly setup linking search path for runtime layer fix missing / before tgi lib path adding missing ld_library_path for cuda stubs in Dockerfile update tgi entrypoint commenting out Python part for TensorRT installation refactored docker image move to TensorRT-LLM v0.11.0 make docker linter happy with same capitalization rule fix typo refactor the compute capabilities detection along with num gpus update TensorRT-LLM to latest version update TensorRT install script to latest update build.rs to link to cuda 12.5 add missing dependant libraries for linking clean up a bit install to decoder_attention target add some custom stuff for nccl linkage fix envvar CARGO_CFG_TARGET_ARCH set at runtime vs compile time use std::env::const::ARCH make sure variable live long enough... look for cuda 12.5 add some more basic info in README.md * Rebase. * Fix autodocs. * Let's try to enable trtllm backend. * Ignore backends/v3 by default. * Fixing client. * Fix makefile + autodocs. * Updating the schema thing + redocly. * Fix trtllm lint. * Adding pb files ? * Remove cargo fmt temporarily. * ? * Tmp. * Remove both check + clippy ? * Backporting telemetry. * Backporting 457fb0a1 * Remove PB from git. * Fixing PB with default member backends/client * update TensorRT-LLM to latest version * provided None for api_key * link against libtensorrt_llm and not libtensorrt-llm --------- Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Co-authored-by: Morgan Funtowicz --- .dockerignore | 2 + .gitignore | 4 + .pre-commit-config.yaml | 2 +- Cargo.lock | 643 ++++++++----- Cargo.toml | 21 +- Dockerfile | 1 + Dockerfile.trtllm | 23 + Dockerfile_amd | 2 + Dockerfile_intel | 2 + Makefile | 6 +- {router => backends}/client/Cargo.toml | 0 {router => backends}/client/build.rs | 0 {router => backends}/client/src/lib.rs | 0 {router => backends}/client/src/v2/client.rs | 0 {router => backends}/client/src/v2/mod.rs | 0 .../client/src/v2/sharded_client.rs | 0 {router => backends}/client/src/v3/client.rs | 0 {router => backends}/client/src/v3/mod.rs | 0 .../client/src/v3/sharded_client.rs | 0 {router => backends}/grpc-metadata/Cargo.toml | 0 {router => backends}/grpc-metadata/src/lib.rs | 0 backends/trtllm/CMakeLists.txt | 63 ++ backends/trtllm/Cargo.toml | 26 + backends/trtllm/Dockerfile | 100 ++ backends/trtllm/README.md | 46 + backends/trtllm/build.rs | 150 +++ backends/trtllm/cmake/fmt.cmake | 6 + backends/trtllm/cmake/json.cmake | 5 + backends/trtllm/cmake/spdlog.cmake | 17 + backends/trtllm/cmake/trtllm.cmake | 42 + .../trtllm/cmake/utils/detect_cuda_arch.cu | 0 backends/trtllm/include/backend.h | 121 +++ backends/trtllm/include/ffi.h | 75 ++ backends/trtllm/include/hardware.h | 59 ++ backends/trtllm/lib/backend.cpp | 146 +++ backends/trtllm/scripts/install_tensorrt.sh | 111 +++ backends/trtllm/src/backend.rs | 329 +++++++ backends/trtllm/src/errors.rs | 15 + backends/trtllm/src/ffi.cpp | 84 ++ backends/trtllm/src/lib.rs | 78 ++ backends/trtllm/src/main.rs | 166 ++++ backends/trtllm/tests/infer_test.cpp | 14 + backends/v3/Cargo.toml | 66 ++ backends/v3/build.rs | 19 + backends/v3/src/backend.rs | 501 ++++++++++ .../v3 => backends/v3/src}/block_allocator.rs | 0 backends/v3/src/client/grpc_client.rs | 284 ++++++ backends/v3/src/client/mod.rs | 76 ++ backends/v3/src/client/sharded_client.rs | 260 ++++++ backends/v3/src/lib.rs | 142 +++ backends/v3/src/main.rs | 208 +++++ .../src/infer/v3 => backends/v3/src}/queue.rs | 36 +- benchmark/Cargo.toml | 2 +- docs/openapi.json | 35 - router/Cargo.toml | 14 +- .../{v3/scheduler.rs => chat_template.rs} | 557 ++---------- router/src/infer/health.rs | 34 - router/src/infer/mod.rs | 280 +----- router/src/infer/tool_grammar.rs | 135 +++ router/src/infer/v2/mod.rs | 2 +- router/src/infer/v2/scheduler.rs | 8 +- router/src/infer/v3/mod.rs | 5 - router/src/lib.rs | 39 +- router/src/logging.rs | 81 ++ router/src/{main.rs => main.rs.back} | 0 router/src/server.rs | 857 ++++++++++++------ router/src/usage_stats.rs | 30 +- router/src/validation.rs | 70 +- update_doc.py | 14 +- 69 files changed, 4719 insertions(+), 1395 deletions(-) create mode 100644 Dockerfile.trtllm rename {router => backends}/client/Cargo.toml (100%) rename {router => backends}/client/build.rs (100%) rename {router => backends}/client/src/lib.rs (100%) rename {router => backends}/client/src/v2/client.rs (100%) rename {router => backends}/client/src/v2/mod.rs (100%) rename {router => backends}/client/src/v2/sharded_client.rs (100%) rename {router => backends}/client/src/v3/client.rs (100%) rename {router => backends}/client/src/v3/mod.rs (100%) rename {router => backends}/client/src/v3/sharded_client.rs (100%) rename {router => backends}/grpc-metadata/Cargo.toml (100%) rename {router => backends}/grpc-metadata/src/lib.rs (100%) create mode 100644 backends/trtllm/CMakeLists.txt create mode 100644 backends/trtllm/Cargo.toml create mode 100644 backends/trtllm/Dockerfile create mode 100644 backends/trtllm/README.md create mode 100644 backends/trtllm/build.rs create mode 100644 backends/trtllm/cmake/fmt.cmake create mode 100644 backends/trtllm/cmake/json.cmake create mode 100644 backends/trtllm/cmake/spdlog.cmake create mode 100644 backends/trtllm/cmake/trtllm.cmake create mode 100644 backends/trtllm/cmake/utils/detect_cuda_arch.cu create mode 100644 backends/trtllm/include/backend.h create mode 100644 backends/trtllm/include/ffi.h create mode 100644 backends/trtllm/include/hardware.h create mode 100644 backends/trtllm/lib/backend.cpp create mode 100755 backends/trtllm/scripts/install_tensorrt.sh create mode 100644 backends/trtllm/src/backend.rs create mode 100644 backends/trtllm/src/errors.rs create mode 100644 backends/trtllm/src/ffi.cpp create mode 100644 backends/trtllm/src/lib.rs create mode 100644 backends/trtllm/src/main.rs create mode 100644 backends/trtllm/tests/infer_test.cpp create mode 100644 backends/v3/Cargo.toml create mode 100644 backends/v3/build.rs create mode 100644 backends/v3/src/backend.rs rename {router/src/infer/v3 => backends/v3/src}/block_allocator.rs (100%) create mode 100644 backends/v3/src/client/grpc_client.rs create mode 100644 backends/v3/src/client/mod.rs create mode 100644 backends/v3/src/client/sharded_client.rs create mode 100644 backends/v3/src/lib.rs create mode 100644 backends/v3/src/main.rs rename {router/src/infer/v3 => backends/v3/src}/queue.rs (95%) rename router/src/infer/{v3/scheduler.rs => chat_template.rs} (67%) delete mode 100644 router/src/infer/health.rs create mode 100644 router/src/infer/tool_grammar.rs delete mode 100644 router/src/infer/v3/mod.rs create mode 100644 router/src/logging.rs rename router/src/{main.rs => main.rs.back} (100%) diff --git a/.dockerignore b/.dockerignore index c69283ec..1c641e7a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,5 @@ aml target server/transformers server/flash-attention +cmake-build-debug/ +cmake-build-release/ diff --git a/.gitignore b/.gitignore index e9ad1808..0de8b848 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,10 @@ target router/tokenizer.json *__pycache__* +backends/v3/src/client/pb +backends/client/src/v2/pb +backends/client/src/v3/pb + # ROCm auto-generated files *.hip server/exllamav2_kernels/exllamav2_kernels/hip/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb3879b2..6f5e685e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,8 +13,8 @@ repos: - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 hooks: - - id: fmt - id: cargo-check + - id: fmt - id: clippy - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.0 diff --git a/Cargo.lock b/Cargo.lock index e7457b73..c4cf8dc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ "once_cell", "serde", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -48,9 +48,9 @@ checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" [[package]] name = "anstream" -version = "0.6.14" +version = "0.6.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" dependencies = [ "anstyle", "anstyle-parse", @@ -63,33 +63,33 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "anstyle-parse" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" dependencies = [ "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", "windows-sys 0.52.0", @@ -121,7 +121,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -160,18 +160,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -234,9 +234,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.7.3" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7d844e282b4b56750b2d4e893b2205581ded8709fddd2b6aa5418c150ca877" +checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77" dependencies = [ "aws-lc-sys", "mirai-annotations", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.18.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a2c29203f6bf296d01141cc8bb9dbd5ecd4c27843f2ee0767bcd5985a927da" +checksum = "2e89b6941c2d1a7045538884d6e760ccfffdf8e1ffc2613d8efa74305e1f3752" dependencies = [ "bindgen", "cc", @@ -272,7 +272,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "itoa", "matchit", "memchr", @@ -302,9 +302,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "itoa", "matchit", @@ -352,7 +352,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -433,7 +433,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.68", + "syn 2.0.72", "which", ] @@ -472,9 +472,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415f8399438eb5e4b2f73ed3152a3448b98149dda642a957ee704e1daa5cf1d8" +checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" [[package]] name = "block-buffer" @@ -487,9 +487,9 @@ dependencies = [ [[package]] name = "built" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" +checksum = "236e6289eda5a812bc6b53c3b024039382a2895fbbeef2d748b2931546d392c4" [[package]] name = "bumpalo" @@ -523,9 +523,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "camino" @@ -600,9 +600,9 @@ checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" [[package]] name = "clap" -version = "4.5.7" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" +checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" dependencies = [ "clap_builder", "clap_derive", @@ -610,9 +610,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.7" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" +checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" dependencies = [ "anstream", "anstyle", @@ -646,9 +646,9 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "colorchoice" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "console" @@ -740,7 +740,7 @@ dependencies = [ "bitflags 2.6.0", "crossterm_winapi", "libc", - "mio", + "mio 0.8.11", "parking_lot", "signal-hook", "signal-hook-mio", @@ -804,10 +804,54 @@ dependencies = [ ] [[package]] -name = "darling" -version = "0.20.9" +name = "cxx" +version = "1.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.72", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -815,27 +859,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "darling_macro" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -865,7 +909,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -875,7 +919,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1155,7 +1199,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1403,9 +1447,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1420,7 +1464,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1438,9 +1482,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", @@ -1462,16 +1506,16 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1489,10 +1533,10 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "log", - "rustls 0.23.10", + "rustls 0.23.12", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1506,7 +1550,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.29", + "hyper 0.14.30", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1640,7 +1684,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1663,12 +1707,12 @@ dependencies = [ [[package]] name = "image" -version = "0.25.1" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +checksum = "99314c8a2152b8ddb211f924cdae532d8c5e4c8bb54728e12fff1b0cd5963a10" dependencies = [ "bytemuck", - "byteorder", + "byteorder-lite", "color_quant", "exr", "gif", @@ -1686,12 +1730,12 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" +checksum = "f79afb8cbee2ef20f59ccd477a218c12a93943d075b492015ecb1bb81f8ee904" dependencies = [ "byteorder-lite", - "thiserror", + "quick-error", ] [[package]] @@ -1781,9 +1825,9 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is_terminal_polyfill" -version = "1.70.0" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "iso8601" @@ -1829,9 +1873,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -1918,12 +1962,12 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1942,6 +1986,15 @@ dependencies = [ "libc", ] +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1966,9 +2019,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loop9" @@ -2023,7 +2076,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" dependencies = [ "cfg-if", - "rayon", ] [[package]] @@ -2044,13 +2096,13 @@ dependencies = [ [[package]] name = "metrics-exporter-prometheus" -version = "0.15.1" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0af7a0d7ced10c0151f870e5e3f3f8bc9ffc5992d32873566ca1f9169ae776" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ "base64 0.22.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-rustls", "hyper-util", "indexmap 2.2.6", @@ -2086,9 +2138,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mime_guess" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" dependencies = [ "mime", "unicase", @@ -2096,18 +2148,18 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4" +checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" dependencies = [ "serde", ] [[package]] name = "minijinja-contrib" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07" +checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" dependencies = [ "minijinja", "serde", @@ -2141,6 +2193,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + [[package]] name = "mirai-annotations" version = "1.12.0" @@ -2165,7 +2229,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2231,7 +2295,7 @@ dependencies = [ "bytes", "futures", "hostname", - "hyper 0.14.29", + "hyper 0.14.30", "muxado", "once_cell", "parking_lot", @@ -2316,9 +2380,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2353,7 +2417,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2424,9 +2488,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.0" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" dependencies = [ "memchr", ] @@ -2461,9 +2525,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.64" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -2482,7 +2546,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2493,9 +2557,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", @@ -2529,6 +2593,20 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "opentelemetry" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + [[package]] name = "opentelemetry-otlp" version = "0.13.0" @@ -2622,7 +2700,27 @@ dependencies = [ "glob", "once_cell", "opentelemetry 0.21.0", - "ordered-float 4.2.0", + "ordered-float 4.2.2", + "percent-encoding", + "rand", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "lazy_static", + "once_cell", + "opentelemetry 0.23.0", + "ordered-float 4.2.2", "percent-encoding", "rand", "thiserror", @@ -2645,9 +2743,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.2.0" +version = "4.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6" dependencies = [ "num-traits", ] @@ -2689,7 +2787,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2731,7 +2829,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2767,9 +2865,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "powerfmt" @@ -2779,9 +2877,12 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +dependencies = [ + "zerocopy 0.6.6", +] [[package]] name = "prettyplease" @@ -2790,7 +2891,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2842,7 +2943,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2882,7 +2983,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.68", + "syn 2.0.72", "tempfile", ] @@ -3044,24 +3145,23 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.7" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67376f469e7e7840d0040bbf4b9b3334005bb167f814621326e4c7ab8cd6e944" +checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" dependencies = [ "avif-serialize", "imgref", "loop9", "quick-error", "rav1e", - "rayon", "rgb", ] [[package]] name = "raw-cpuid" -version = "11.0.2" +version = "11.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" dependencies = [ "bitflags 2.6.0", ] @@ -3175,7 +3275,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-tls", "ipnet", "js-sys", @@ -3203,9 +3303,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.37" +version = "0.8.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" +checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" dependencies = [ "bytemuck", ] @@ -3242,9 +3342,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19549741604902eb99a7ed0ee177a0663ee1eda51a29f71401f166e47e77806a" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -3337,9 +3437,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.10" +version = "0.23.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" dependencies = [ "aws-lc-rs", "log", @@ -3352,9 +3452,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile 2.1.2", @@ -3390,9 +3490,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3436,6 +3536,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "sct" version = "0.7.1" @@ -3448,9 +3554,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -3461,9 +3567,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" dependencies = [ "core-foundation-sys", "libc", @@ -3480,31 +3586,32 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3521,9 +3628,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" dependencies = [ "serde", ] @@ -3578,12 +3685,12 @@ dependencies = [ [[package]] name = "signal-hook-mio" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", - "mio", + "mio 0.8.11", "signal-hook", ] @@ -3709,7 +3816,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -3731,9 +3838,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.68" +version = "2.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" dependencies = [ "proc-macro2", "quote", @@ -3832,9 +3939,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" [[package]] name = "tempfile" @@ -3937,7 +4044,6 @@ dependencies = [ "serde", "serde_json", "sysinfo", - "text-generation-client", "thiserror", "tokenizers", "tokio", @@ -3946,30 +4052,79 @@ dependencies = [ "tracing", "tracing-opentelemetry 0.21.0", "tracing-subscriber", + "ureq", "utoipa", "utoipa-swagger-ui", "uuid", "vergen", ] +[[package]] +name = "text-generation-router-v3" +version = "2.2.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.5", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4071,21 +4226,20 @@ dependencies = [ [[package]] name = "tokio" -version = "1.38.0" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", - "mio", - "num_cpus", + "mio 1.0.1", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -4100,13 +4254,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4136,7 +4290,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.10", + "rustls 0.23.12", "rustls-pki-types", "tokio", ] @@ -4168,9 +4322,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.14" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" dependencies = [ "serde", "serde_spanned", @@ -4180,18 +4334,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" dependencies = [ "indexmap 2.2.6", "serde", @@ -4215,7 +4369,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -4265,7 +4419,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4297,7 +4451,7 @@ dependencies = [ "bitflags 2.6.0", "bytes", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "pin-project-lite", "tower-layer", @@ -4336,7 +4490,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4402,7 +4556,25 @@ dependencies = [ "tracing-core", "tracing-log 0.2.0", "tracing-subscriber", - "web-time", + "web-time 0.2.4", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time 1.1.0", ] [[package]] @@ -4591,7 +4763,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4629,7 +4801,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4657,9 +4829,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.3.1" +version = "8.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +checksum = "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", "cargo_metadata", @@ -4679,9 +4851,9 @@ checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -4729,7 +4901,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", "wasm-bindgen-shared", ] @@ -4763,7 +4935,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4794,6 +4966,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" @@ -4869,7 +5051,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4878,7 +5060,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4905,7 +5087,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4940,18 +5122,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4968,9 +5150,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4986,9 +5168,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -5004,15 +5186,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -5028,9 +5210,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -5046,9 +5228,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -5064,9 +5246,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -5082,15 +5264,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" dependencies = [ "memchr", ] @@ -5143,22 +5325,43 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" dependencies = [ - "zerocopy-derive", + "byteorder", + "zerocopy-derive 0.6.6", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive 0.7.35", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", ] [[package]] @@ -5239,9 +5442,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" +checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" dependencies = [ "zune-core", ] diff --git a/Cargo.toml b/Cargo.toml index 8a1fce05..dca8fdbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,19 @@ [workspace] members = [ - "benchmark", - "router", - "router/client", - "router/grpc-metadata", - "launcher" + "benchmark", + "backends/v3", + "backends/grpc-metadata", + "backends/trtllm", + "backends/client", + "launcher" +] +default-members = [ + "benchmark", + "backends/v3", + "backends/grpc-metadata", + # "backends/trtllm", + "backends/client", + "launcher" ] resolver = "2" @@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference" base64 = "0.22.0" tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } +metrics = { version = "0.23.0" } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } [profile.release] incremental = true diff --git a/Dockerfile b/Dockerfile index 8bc87a0e..e6f01ec2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json diff --git a/Dockerfile.trtllm b/Dockerfile.trtllm new file mode 100644 index 00000000..4543ae80 --- /dev/null +++ b/Dockerfile.trtllm @@ -0,0 +1,23 @@ +# All the tooling for CUDA +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +WORKDIR /usr/src/tgi/backends/trtllm +RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget + +COPY . /usr/src/tgi +RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh +RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include . +RUN cmake --build build --parallel -t tgi_trtllm_backend_impl + +# All the tooling for Rust +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +WORKDIR /usr/src + +# Include CUDA related libraries and tools to the Rust based image +COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda +COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build +ENV PATH=/usr/local/cuda/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH + +RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3 diff --git a/Dockerfile_amd b/Dockerfile_amd index 0aebeee5..51231638 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Dockerfile_intel b/Dockerfile_intel index 6a803a32..d20f0a01 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Makefile b/Makefile index 27b5efcf..db26e3bf 100644 --- a/Makefile +++ b/Makefile @@ -6,13 +6,13 @@ install-integration-tests: cd clients/python && pip install . install-router: - cd router && cargo install --locked --path . + cargo install --path backends/v3/ install-launcher: - cd launcher && cargo install --locked --path . + cargo install --path launcher/ install-benchmark: - cd benchmark && cargo install --locked --path . + cargo install --path benchmark/ install: install-server install-router install-launcher install-custom-kernels diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml similarity index 100% rename from router/client/Cargo.toml rename to backends/client/Cargo.toml diff --git a/router/client/build.rs b/backends/client/build.rs similarity index 100% rename from router/client/build.rs rename to backends/client/build.rs diff --git a/router/client/src/lib.rs b/backends/client/src/lib.rs similarity index 100% rename from router/client/src/lib.rs rename to backends/client/src/lib.rs diff --git a/router/client/src/v2/client.rs b/backends/client/src/v2/client.rs similarity index 100% rename from router/client/src/v2/client.rs rename to backends/client/src/v2/client.rs diff --git a/router/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs similarity index 100% rename from router/client/src/v2/mod.rs rename to backends/client/src/v2/mod.rs diff --git a/router/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs similarity index 100% rename from router/client/src/v2/sharded_client.rs rename to backends/client/src/v2/sharded_client.rs diff --git a/router/client/src/v3/client.rs b/backends/client/src/v3/client.rs similarity index 100% rename from router/client/src/v3/client.rs rename to backends/client/src/v3/client.rs diff --git a/router/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs similarity index 100% rename from router/client/src/v3/mod.rs rename to backends/client/src/v3/mod.rs diff --git a/router/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs similarity index 100% rename from router/client/src/v3/sharded_client.rs rename to backends/client/src/v3/sharded_client.rs diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml similarity index 100% rename from router/grpc-metadata/Cargo.toml rename to backends/grpc-metadata/Cargo.toml diff --git a/router/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs similarity index 100% rename from router/grpc-metadata/src/lib.rs rename to backends/grpc-metadata/src/lib.rs diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt new file mode 100644 index 00000000..425b2d7b --- /dev/null +++ b/backends/trtllm/CMakeLists.txt @@ -0,0 +1,63 @@ +cmake_minimum_required(VERSION 3.20) + +project(tgi-trtllm-backend VERSION 1.0.0) +set(CMAKE_CXX_STANDARD 20) + +include(FetchContent) +include(ExternalProject) + +option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) +option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) +set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") +set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") +set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") +set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") + +# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features +find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) + +#### External dependencies #### +include(cmake/fmt.cmake) +include(cmake/json.cmake) +include(cmake/spdlog.cmake) +include(cmake/trtllm.cmake) + +# Let's build TRTLLM as part of CMake +add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") + +# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so +set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE) + +# TGI TRTLLM Backend definition +add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h) +include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +target_include_directories(tgi_trtllm_backend_impl PRIVATE + $ + $ +) +target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") +target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml) +target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt) + +# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back +install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) +install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) + +#### Unit Tests #### +if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) + message(STATUS "Building tests") + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2 + GIT_TAG v3.6.0 + ) + FetchContent_MakeAvailable(Catch2) + + # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp) + # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml) + + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) + include(CTest) + include(Catch) + # catch_discover_tests(tgi_trtllm_backend_tests) +endif () diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 00000000..7079d3d1 --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "text-generation-backends-trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1" +async-stream = "0.3" +cxx = "1.0" +text-generation-router = { path = "../../router" } +tokenizers = { version = "0.19", features = ["hf-hub"] } +tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.15" +clap = { version = "4.5", features = ["derive"] } +thiserror = "1.0.62" +tracing = "0.1" +tracing-opentelemetry = "0.24" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +log = { version = "0.4", features = [] } + +[build-dependencies] +cmake = "0.1" +cxx-build = { version = "1.0", features = ["parallel"] } +pkg-config = "0.3" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile new file mode 100644 index 00000000..60ad03f7 --- /dev/null +++ b/backends/trtllm/Dockerfile @@ -0,0 +1,100 @@ +ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" +ARG OMPI_VERSION="4.1.6" + +# Build dependencies resolver stage +FROM lukemathwalker/cargo-chef:latest AS chef +WORKDIR /usr/src/text-generation-inference + +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +# CUDA dependent dependencies resolver stage +FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt update && apt install -y \ + build-essential \ + cmake \ + curl \ + gcc \ + g++ \ + git \ + git-lfs \ + libssl-dev \ + ninja-build \ + pkg-config \ + python3 \ + python3-setuptools \ + tar \ + wget + +ENV TGI_INSTALL_PREFIX=/usr/local/tgi +ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt + +# Install OpenMPI +FROM cuda-builder AS mpi-builder +ARG OMPI_VERSION + +ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" +RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ + mkdir /usr/src/mpi && \ + tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ + cd /usr/src/mpi && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ + make -j all && \ + make install && \ + rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" + +# Install TensorRT +FROM cuda-builder AS trt-builder +COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh +RUN chmod +x /opt/install_tensorrt.sh && \ + /opt/install_tensorrt.sh + +# Build Backend +FROM cuda-builder AS tgi-builder +WORKDIR /usr/src/text-generation-inference + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ + chmod -R a+w /root/.rustup && \ + chmod -R a+w /root/.cargo + +ENV PATH="/root/.cargo/bin:$PATH" +RUN cargo install cargo-chef + +# Cache dependencies +COPY --from=planner /usr/src/text-generation-inference/recipe.json . +RUN cargo chef cook --release --recipe-path recipe.json + +# Build actual TGI +ARG CUDA_ARCH_LIST +ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH" +ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" +ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" + +COPY . . +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm + +FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime +WORKDIR /usr/local/tgi/bin + +ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" + +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi +COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher + +FROM runtime + +LABEL co.huggingface.vendor="Hugging Face Inc." +LABEL org.opencontainers.image.authors="hardware@hf.co" + +ENTRYPOINT ["./text-generation-launcher"] +CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"] diff --git a/backends/trtllm/README.md b/backends/trtllm/README.md new file mode 100644 index 00000000..94064504 --- /dev/null +++ b/backends/trtllm/README.md @@ -0,0 +1,46 @@ +# Text Generation Inference - TensorRT-LLM Backend Implementation + +## Description + +This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API + +## Simplified Request Sequence + +```mermaid +sequenceDiagram + actor User + participant TextGenerationInference.HttpServer + participant TextGenerationInference.TensorRtLlmBackend + participant TextGenerationInference.TensorRtLlmWorkerThread + participant TensorRtLlm.Executor + participant Nvidia.Gpu + User ->> TextGenerationInference.HttpServer: POST /generate + TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters + TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher + activate Nvidia.Gpu + TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier + rect rgb(10, 92, 54) + loop every 100us + rect rgb(15, 81, 50) + alt Acquire lock to query executor + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated + else There are new generated tokens + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted) + rect rgb(11, 110, 79) + alt Generated token is final + TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection + else Generated token is not final + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded + end + end + end + end + deactivate Nvidia.Gpu + end + end + +``` diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs new file mode 100644 index 00000000..08638262 --- /dev/null +++ b/backends/trtllm/build.rs @@ -0,0 +1,150 @@ +use cxx_build::CFG; +use pkg_config; +use std::env; +use std::env::consts::ARCH; +use std::path::{absolute, PathBuf}; + +const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"]; +const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); +const CUDA_REQUIRED_VERSION: &str = "12.5"; +const MPI_REQUIRED_VERSION: &str = "4.1"; +const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); +const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); +const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); + +// Dependencies +const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; +const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; +const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ + ("dylib", "tensorrt_llm"), + ("static", "tensorrt_llm_executor_static"), + ("dylib", "tensorrt_llm_nvrtc_wrapper"), + ("dylib", "nvinfer_plugin_tensorrt_llm"), + ("dylib", "decoder_attention"), +]; + +macro_rules! probe { + ($name: expr, $version: expr) => { + if let Err(_) = pkg_config::probe_library($name) { + pkg_config::probe_library(&format!("{}-{}", $name, $version)) + .expect(&format!("Failed to locate {}", $name)); + } + }; +} + +fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { + // Build the backend implementation through CMake + let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); + let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"); + let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default + + let mut install_path = PathBuf::from(install_path); + if !install_path.is_absolute() { + install_path = absolute(out_dir).expect("cannot happen").join(install_path); + } + + let _ = cmake::Config::new(".") + .uses_cxx11() + .generator("Ninja") + .profile(match is_debug { + true => "Debug", + false => "Release", + }) + .env("OPT_LEVEL", opt_level) + .define("CMAKE_INSTALL_PREFIX", &install_path) + .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") + .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) + .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path) + .build(); + + // Additional transitive CMake dependencies + let deps_folder = out_dir.join("build").join("_deps"); + for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES { + let dep_name = match is_debug { + true => format!("{}d", dependency), + false => String::from(dependency), + }; + let dep_path = deps_folder.join(format!("{}-build", dependency)); + println!("cargo:rustc-link-search={}", dep_path.display()); + println!("cargo:rustc-link-lib=static={}", dep_name); + } + + // Emit linkage information from the artifacts we just built + let install_lib_path = install_path.join("lib"); + + println!( + r"cargo:warning=Adding link search path: {}", + install_lib_path.display() + ); + println!(r"cargo:rustc-link-search={}", install_lib_path.display()); + + (PathBuf::from(install_path), deps_folder) +} + +fn build_ffi_layer(deps_folder: &PathBuf) { + CFG.include_prefix = "backends/trtllm"; + cxx_build::bridge("src/lib.rs") + .static_flag(true) + .include(deps_folder.join("fmt-src").join("include")) + .include(deps_folder.join("spdlog-src").join("include")) + .include(deps_folder.join("json-src").join("include")) + .include(deps_folder.join("trtllm-src").join("cpp").join("include")) + .include("/usr/local/cuda/include") + .include("/usr/local/tensorrt/include") + .file("src/ffi.cpp") + .std("c++20") + .compile("tgi_trtllm_backend"); + + println!("cargo:rerun-if-changed=CMakeLists.txt"); + println!("cargo:rerun-if-changed=include/backend.h"); + println!("cargo:rerun-if-changed=lib/backend.cpp"); + println!("cargo:rerun-if-changed=include/ffi.h"); + println!("cargo:rerun-if-changed=src/ffi.cpp"); +} + +fn main() { + // Misc variables + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let build_profile = env::var("PROFILE").unwrap(); + let (is_debug, opt_level) = match build_profile.as_ref() { + "debug" => (true, "0"), + _ => (false, "3"), + }; + + // Build the backend + let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir); + + // Build the FFI layer calling the backend above + build_ffi_layer(&deps_folder); + + // Emit linkage search path + probe!("ompi", MPI_REQUIRED_VERSION); + + // Probe CUDA & co. with pkg-config + CUDA_TRANSITIVE_DEPS.iter().for_each(|name| { + probe!(name, CUDA_REQUIRED_VERSION); + }); + + // NCCL is slightly trickier because it might not have a pkgconfig installed + let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH); + let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default); + println!(r"cargo:rustc-link-search=native={}", nccl_library_path); + println!("cargo:rustc-link-lib=dylib=nccl"); + + // TensorRT + let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib"); + println!(r"cargo:rustc-link-search=native={}", tensort_library_path); + println!("cargo:rustc-link-lib=dylib=nvinfer"); + + // TensorRT-LLM + TENSORRT_LLM_TRANSITIVE_DEPS + .iter() + .for_each(|(link_type, name)| { + println!("cargo:rustc-link-lib={}={}", link_type, name); + }); + + // Backend + BACKEND_DEPS.iter().for_each(|name| { + println!("cargo:rustc-link-lib=static={}", name); + }); +} diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake new file mode 100644 index 00000000..f94a9c56 --- /dev/null +++ b/backends/trtllm/cmake/fmt.cmake @@ -0,0 +1,6 @@ +FetchContent_Declare( + fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt + GIT_TAG 11.0.1 +) +FetchContent_MakeAvailable(fmt) diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake new file mode 100644 index 00000000..29e5753b --- /dev/null +++ b/backends/trtllm/cmake/json.cmake @@ -0,0 +1,5 @@ +fetchcontent_declare( + json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz +) +fetchcontent_makeavailable(json) diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake new file mode 100644 index 00000000..c4ee5c97 --- /dev/null +++ b/backends/trtllm/cmake/spdlog.cmake @@ -0,0 +1,17 @@ +set(SPDLOG_USE_FMT ON) +set(SPDLOG_BUILD_SHARED OFF) +set(SPDLOG_FMT_EXTERNAL ON) + +# Define the level at which SPDLOG_ compilation level is defined +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) +else () + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) +endif () + +fetchcontent_declare( + spdlog + GIT_REPOSITORY https://github.com/gabime/spdlog.git + GIT_TAG v1.14.1 +) +fetchcontent_makeavailable(spdlog) diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake new file mode 100644 index 00000000..e59ad4cf --- /dev/null +++ b/backends/trtllm/cmake/trtllm.cmake @@ -0,0 +1,42 @@ +set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}) + +set(USE_CXX11_ABI ON) +set(BUILD_PYT OFF) +set(BUILD_PYBIND OFF) +set(BUILD_MICRO_BENCHMARKS OFF) +set(BUILD_BENCHMARKS OFF) +set(BUILD_TESTS OFF) +set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST}) + +message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + set(FAST_BUILD ON) + set(NVTX_DISABLE OFF) +else () + set(FAST_BUILD OFF) + set(FAST_MATH ON) + set(NVTX_DISABLE ON) +endif () + +fetchcontent_declare( + trtllm + GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git + GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1 + GIT_SHALLOW FALSE +) +fetchcontent_makeavailable(trtllm) + +message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") +execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") +execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") + +# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here +set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name") +set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}" + CACHE INTERNAL "nvrtc wrapper library path") + +# The same Executor Static library +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name") +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path") diff --git a/backends/trtllm/cmake/utils/detect_cuda_arch.cu b/backends/trtllm/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 00000000..e69de29b diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h new file mode 100644 index 00000000..7990e76b --- /dev/null +++ b/backends/trtllm/include/backend.h @@ -0,0 +1,121 @@ +// +// Created by Morgan Funtowicz on 6/30/24. +// + +#ifndef TGI_TRTLLM_BACKEND_H +#define TGI_TRTLLM_BACKEND_H + +#include +#include +#include +#include + +#include + +#include +#include +#include + +using json = nlohmann::json; +namespace tle = tensorrt_llm::executor; + +namespace huggingface::tgi::backends { + using RequestId = tle::IdType; + using TokenId = tle::TokenIdType; + + /** + * Initialize all the components required by TRTLLM. + * It is required to call this function before attempting to load any engine + */ + void InitializeBackend(); + + /** + * + * @param config TensorRT-LLM configuration object + * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode + * @return + */ + tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath); + + /** + * Get the sampling configuration from the parameters provided by TGI + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + tle::SamplingConfig GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + */ + class TensorRtLlmBackend { + private: + const json config; + tle::Executor executor; + + public: + explicit TensorRtLlmBackend( + const std::filesystem::path &engineFolder, + const std::filesystem::path &executorWorker + ); + + /** + * Indicate if the backend is ready to accept incoming request + * @return true if ready, false otherwise + */ + [[nodiscard]] bool IsReady() const; + + /** + * Query the executor for the number of token available for pulling + * @return + */ + [[nodiscard]] size_t NumResponsesReady() const; + + /** + * Submit a new generation task to the executor + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return Request id related to this generation for reference + */ + [[nodiscard]] RequestId Submit( + const std::vector &tokens, + int32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + * @param requestId The request id to poll the generation results + * @return + */ + std::vector Poll(RequestId requestId); + + /** + * Stop the underlying executor + */ + void Shutdown(); + }; +} + + +#endif //TGI_TRTLLM_BACKEND_H diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h new file mode 100644 index 00000000..fe0be9fc --- /dev/null +++ b/backends/trtllm/include/ffi.h @@ -0,0 +1,75 @@ +// +// Created by mfuntowicz on 7/11/24. +// + +#ifndef TGI_TRTLLM_BACKEND_FFI_H +#define TGI_TRTLLM_BACKEND_FFI_H + +#include +#include "backend.h" + +namespace huggingface::tgi::backends { + class TensorRtLlmBackendImpl; +} + +#include "backends/trtllm/src/lib.rs.h" + + +namespace huggingface::tgi::backends { + +// struct GenerationContext; + + class TensorRtLlmBackendImpl : public TensorRtLlmBackend { + public: + /*** + * + * @param engineFolder + * @param executorWorker + */ + TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); + + /*** + * + * @return + */ + bool IsReady() const; + + /*** + * + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] + uint64_t + Submit(rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); + + /*** + * + * @param requestId + * @param ctx + * @param callback + * @return + */ + size_t StreamTokens( + const RequestId requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback); + }; + + /*** + * + * @param engineFolder + * @return + */ + std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker); +} + +#endif //TGI_TRTLLM_BACKEND_FFI_H diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h new file mode 100644 index 00000000..da0bf4f3 --- /dev/null +++ b/backends/trtllm/include/hardware.h @@ -0,0 +1,59 @@ +// +// Created by mfuntowicz on 7/23/24. +// + +#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H +#define TGI_TRTLLM_BACKEND_HARDWARE_H + +#include +#include +#include +#include +#include + +namespace huggingface::hardware::cuda { + +#define AMPERE_SM_MAJOR 8 +#define HOPPER_SM_MAJOR 8 + + /** + * Store information about the version of the CUDA Compute Capabilities detected on the device + */ + struct CudaComputeCapabilities { + int32_t major; + int32_t minor; + + [[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; } + + [[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; } + }; + + CudaComputeCapabilities GetCudaComputeCapabilities() { + // Get the compute capabilities of the current hardware + nvmlDevice_t device; + CudaComputeCapabilities capabilities{0, 0}; + if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) { + SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0"); + if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) { + SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor); + } + } + + return capabilities; + } + + /** + * Return the number of GPU detected. If no GPU is detected, return size_t::max() + * @return + */ + std::optional GetNumDevices() { + uint32_t numGpus = 0; + if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) { + return std::optional(numGpus); + } else { + return std::nullopt; + } + } +} + +#endif //TGI_TRTLLM_BACKEND_HARDWARE_H diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp new file mode 100644 index 00000000..c066a6d6 --- /dev/null +++ b/backends/trtllm/lib/backend.cpp @@ -0,0 +1,146 @@ +#include + +#include +#include +#include + +#include "backend.h" +#include "hardware.h" + +void huggingface::tgi::backends::InitializeBackend() { + SPDLOG_INFO("Initializing Backend..."); + nvmlInit_v2(); + initTrtLlmPlugins(); + + const auto numGpus = huggingface::hardware::cuda::GetNumDevices(); + if (numGpus.has_value()) { + SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value()); + } else { + SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system"); + } +} + +[[nodiscard]] +tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) { + tle::ExecutorConfig execConfig(1); + + // Retrieve the compute capabilities to enable some options at runtime + const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities(); + + // Single engine (TP = PP = 1) -> using leader mode (no MPI involved) + if (config["/pretrained_config/mapping/world_size"_json_pointer].get() == 1) { + SPDLOG_INFO("Detected single engine deployment, using leader mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kLEADER, + std::nullopt, + std::nullopt, + std::nullopt + )); + } else { // Multiple engines -> using orchestrator mode (MPI involved) + SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kORCHESTRATOR, + std::nullopt, + std::nullopt, + tle::OrchestratorConfig(true, workerPath, nullptr, true) + )); + } + + // Define some configuration variables + execConfig.setKvCacheConfig(tle::KvCacheConfig(true)); + execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere()); + return execConfig; +} + +tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed) { + return tle::SamplingConfig( + 1, // TGI only use a single beam + topK, + topP, + std::nullopt, + std::nullopt, + std::nullopt, + seed, + temperature, + temperature, + std::nullopt, + repetition_penalty, + std::nullopt, + frequency_penalty + ); +} + +huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend( + const std::filesystem::path &enginesFolder, + const std::filesystem::path &executorWorker +) : + config(json::parse(std::ifstream(enginesFolder / "config.json"))), + executor( + enginesFolder, + tensorrt_llm::executor::ModelType::kDECODER_ONLY, + GetExecutorConfig(config, executorWorker.string() + )) { + SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref()); +} + +bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const { + return executor.canEnqueueRequests(); +} + +[[nodiscard("Returned number of requests needs to be consumed")]] +size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const { + return executor.getNumResponsesReady(); +} + +[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]] +tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( + const std::vector &tokens, + const int32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed +) { +#ifdef NDEBUG + SPDLOG_DEBUG( + FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"), + tokens.size(), + executor.getLatestIterationStats().back().numActiveRequests + ); +#else + SPDLOG_DEBUG( + FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"), + fmt::join(tokens, ", "), + executor.getLatestIterationStats().front().numActiveRequests + ); +#endif + + const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get(); + const auto maxNewTokens = static_cast(std::max(1ul, maxNumTokens - tokens.size())); + + const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed); + const auto output = tle::OutputConfig(true, false, false, true, false); + return executor.enqueueRequest( + tle::Request{tokens, maxNewTokens, true, sampling, output}); +} + +[[nodiscard("Generated tokens result must be used")]] +std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { + SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId); + return executor.awaitResponses(requestId); +} + + +void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() { + SPDLOG_INFO("Shutting down executor"); + executor.shutdown(); +} diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh new file mode 100755 index 00000000..e0e2dd17 --- /dev/null +++ b/backends/trtllm/scripts/install_tensorrt.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +set -ex + +TRT_VER="10.2.0.19" +CUDA_VER="12.5" +CUDNN_VER="9.2.1.18-1" +NCCL_VER="2.22.3-1+cuda12.5" +CUBLAS_VER="12.5.3.2-1" +NVRTC_VER="12.5.82-1" + +for i in "$@"; do + case $i in + --TRT_VER=?*) TRT_VER="${i#*=}";; + --CUDA_VER=?*) CUDA_VER="${i#*=}";; + --CUDNN_VER=?*) CUDNN_VER="${i#*=}";; + --NCCL_VER=?*) NCCL_VER="${i#*=}";; + --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";; + *) ;; + esac + shift +done + +NVCC_VERSION_OUTPUT=$(nvcc --version) +if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then + echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}." + exit 1 +fi + +install_ubuntu_requirements() { + apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates + ARCH=$(uname -m) + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi + curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb + dpkg -i cuda-keyring_1.0-1_all.deb + + apt-get update + if [[ $(apt list --installed | grep libcudnn9) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcudnn9* + fi + if [[ $(apt list --installed | grep libnccl) ]]; then + apt-get remove --purge -y --allow-change-held-packages libnccl* + fi + if [[ $(apt list --installed | grep libcublas) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcublas* + fi + if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then + apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* + fi + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER} + apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER} + apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} + # NVRTC static library doesn't exist in NGC PyTorch container. + NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} + apt-get clean + rm -rf /var/lib/apt/lists/* +} + +install_centos_requirements() { + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + yum -y update + yum -y install epel-release + yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER} + yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} + yum clean all +} + +install_tensorrt() { + #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') + #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") + TRT_CUDA_VERSION="12.5" + + if [ -z "$RELEASE_URL_TRT" ];then + ARCH=${TRT_TARGETARCH} + if [ -z "$ARCH" ];then ARCH=$(uname -m);fi + if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi + if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz + fi + wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar + tar -xf /tmp/TensorRT.tar -C /usr/local/ + mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt + # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl + rm -rf /tmp/TensorRT.tar +} + +# Install base packages depending on the base OS +ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +case "$ID" in + debian) + install_ubuntu_requirements + install_tensorrt + ;; + ubuntu) + install_ubuntu_requirements + install_tensorrt + ;; + centos) + install_centos_requirements + install_tensorrt + ;; + *) + echo "Unable to determine OS..." + exit 1 + ;; +esac diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs new file mode 100644 index 00000000..b26d06a6 --- /dev/null +++ b/backends/trtllm/src/backend.rs @@ -0,0 +1,329 @@ +use std::future::Future; +use std::path::Path; +use std::pin::{pin, Pin}; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; + +use async_trait::async_trait; +use cxx::UniquePtr; +use log::{error, warn}; +use tokenizers::Tokenizer; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::RwLock; +use tokio::time::{sleep, Instant}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::{Stream, StreamExt}; +use tracing::{instrument, span, Level}; + +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidationError::UnsupportedModality; +use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; +use text_generation_router::{FinishReason, Token}; + +use crate::errors::TensorRtLlmBackendError; +use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl}; + +// Value used to poll the state of the generation stream +static POLLING_INTERVAL_US: OnceLock = OnceLock::new(); + +type InferResult = Result; + +pub(crate) struct Generation { + executor: Arc>>, + done: Arc, +} + +/// Holds the user provided input to be executed along with a channel allowing +/// to bubble up all the generated tokens for that tokens the to end stream. +pub struct GenerationContext { + sender: UnboundedSender>, + tokenizer: Arc, + tokens: Vec, + done: Arc, + queued: Instant, + start: Option, +} + +impl Stream for Generation { + type Item = usize; + + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + let interval = POLLING_INTERVAL_US.get_or_init(|| { + u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100")) + .expect("Invalid value provided for envvar POLLING_INTERVAL_US") + }); + + if !self.done.load(Ordering::Relaxed) { + let backend = pin!(self.executor.read()); + let status = match backend.poll(ctx) { + Poll::Ready(executor_r) => { + let ready = executor_r.num_responses_ready(); + if ready == 0 { + Poll::Pending + } else { + Poll::Ready(Some(ready)) + } + } + Poll::Pending => Poll::Pending, + }; + + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_micros(*interval)).await; + waker.wake(); + }); + + status + } else { + Poll::Ready(None) // end of stream + } + } + + fn size_hint(&self) -> (usize, Option) { + (1, None) + } +} + +unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} + +/// Implements the logic to execute generation with TensorRT-LLM executor API in background +pub struct TensorRtLlmBackend { + tokenizer: Arc, + + // Backing the backend behind a RwLock to allow concurrent read access to retrieve + // the number of available tokens (read only) in the Generation stream + backend: Arc>>, +} + +impl TensorRtLlmBackend { + pub fn new + Send + 'static, PP: AsRef + Send + 'static>( + tokenizer: Tokenizer, + engine_folder: P, + executor_worker_path: PP, + ) -> Result { + Ok(TensorRtLlmBackend { + tokenizer: Arc::new(tokenizer), + backend: Arc::new(RwLock::new(create_tensorrt_llm_backend( + engine_folder.as_ref().to_str().unwrap(), + executor_worker_path.as_ref().to_str().unwrap(), + ))), + }) + } + + fn validate(request: &ValidGenerateRequest) -> InferResult<&String> { + if request.top_n_tokens > 1 { + return Err(InferError::ValidationError( + ValidationError::TopNTokensDisabled, + )); + } + + // TODO: Is it really needed? How can it be validated before? + if request.parameters.grammar.is_some() { + return Err(InferError::ValidationError(ValidationError::Grammar)); + } + + match request.inputs.len() { + 0 => Err(InferError::ValidationError(ValidationError::EmptyInput)), + 2.. => Err(InferError::GenerationError( + "TensorRT-LLM backend don't support multi-chunk".into(), + )), + 1 => match request.inputs.first().expect("Single item-chunk") { + Chunk::Text(text) => Ok(text), + Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))), + }, + } + } + + fn generate( + &self, + sender: UnboundedSender>, + tokens: Vec, + top_k: u32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) { + let tokenizer = Arc::clone(&self.tokenizer); + let executor = Arc::clone(&self.backend); + + // Let's push this in async context + tokio::spawn(async move { + // Define the generation state + let mut generation = Generation { + executor: executor.clone(), + done: Arc::new(AtomicBool::new(false)), + }; + + // Define the context over the generation + // TODO(asap): Do we really need so many shared-ownership? + let ctx = Box::new(GenerationContext { + sender: sender.clone(), + tokenizer, + tokens: vec![], + done: Arc::clone(&generation.done), + start: None, + queued: Instant::now(), + }); + + // We are leaking the context on-purpose to avoid the box being dropped while there are + // still computation ongoing + // TODO(asap): Can we achieve the same with an Arc> without the need to go unsafe? + let ctx_ = Box::leak(ctx); + + // Submit the request to the batcher + let request_id = span!(Level::DEBUG, "submit") + .in_scope(|| async { + let mut handle = executor.write().await; + let request_id = handle.pin_mut().submit( + &tokens, + top_k as i32, + top_p, + temperature, + repetition_penalty, + frequency_penalty, + seed, + ); + + request_id + }) + .await; + + while let Some(_) = generation.next().await { + let mut executor_w = executor.write().await; + let executor = executor_w.pin_mut(); + + span!(Level::DEBUG, "decode") + .in_scope(|| async { + unsafe { + executor.stream_tokens( + request_id, + ctx_, + |ctx: *mut GenerationContext, step: GenerationStep| { + let inner_ctx = &mut *ctx; + + // Update the timestamp at which the request started effectively + // Can be a bit off, would need to be before the callback, let's see + inner_ctx.start.get_or_insert(Instant::now()); + inner_ctx.done.store(step.is_final, Ordering::Relaxed); + + // Ensure we are not running into errors + let parcel = if !step.has_error { + // Insert the latest generated token to the tracker + inner_ctx.tokens.push(step.token_id); + + // Decode the token + let text = inner_ctx + .tokenizer + .decode(&[step.token_id], true) + .expect("Failed to decode token"); + + let special = inner_ctx + .tokenizer + .get_added_vocabulary() + .is_special_token(&text); + + // Create the structure holding the token + let token = Token { + id: step.token_id, + text, + logprob: step.log_prob, + special, + }; + + if step.is_final { + let generated_text = inner_ctx + .tokenizer + .decode(&inner_ctx.tokens, true) + .expect("Failed to decode generated_tokens"); + + Ok(InferStreamResponse::End { + token, + top_tokens: vec![], + generated_text: GeneratedText { + text: generated_text, + generated_tokens: inner_ctx.tokens.len() as u32, + finish_reason: FinishReason::EndOfSequenceToken, + seed: None, + }, + start: inner_ctx.start.unwrap_or(Instant::now()), + queued: inner_ctx.queued, + }) + } else { + Ok(InferStreamResponse::Intermediate { + token, + top_tokens: vec![], + }) + } + } else { + error!("Error caught while decoding: {}", &step.error_msg); + Err(InferError::GenerationError(step.error_msg)) + }; + + // Send the parcel to the client + inner_ctx + .sender + .send(parcel) + .expect("Failed to sent msg through the channel"); + }, + ); + } + }) + .await; + } + + // "Properly" free the shared context... + // TODO: clean that piece of sh** asap + unsafe { + let _ = Box::from_raw(ctx_); + } + }); + } +} + +#[async_trait] +impl Backend for TensorRtLlmBackend { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> InferResult>> { + // Let's add a few more validation + let input = TensorRtLlmBackend::validate(&request)?; + + // Channel to stream the generated token as they come from the worker thread back to the transport layer + let (sender, receiver) = unbounded_channel(); + + // Unpack parameters + let params = &request.parameters; + + // Preprocess the inputs to send to TRTLLM backend + let encoding = self + .tokenizer + .encode(input.as_str(), true) + .map_err(|e| InferError::GenerationError(e.to_string()))?; + + // Generate the response + self.generate( + sender, + Vec::from(encoding.get_ids()), + params.top_k, + params.top_p, + params.temperature, + params.repetition_penalty, + params.frequency_penalty, + params.seed, + ); + + Ok(UnboundedReceiverStream::new(receiver)) + } + + async fn health(&self, _current_health: bool) -> bool { + true + } +} diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs new file mode 100644 index 00000000..a672d2a4 --- /dev/null +++ b/backends/trtllm/src/errors.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +use text_generation_router::server; + +#[derive(Debug, Error)] +pub enum TensorRtLlmBackendError { + #[error("Tokenizer error: {0}")] + Tokenizer(String), + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp new file mode 100644 index 00000000..d6317a68 --- /dev/null +++ b/backends/trtllm/src/ffi.cpp @@ -0,0 +1,84 @@ +// +// Created by mfuntowicz on 6/30/24. +// +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include "backends/trtllm/include/ffi.h" + + +huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl( + const std::string_view &engineFolder, + const std::string_view &executorWorker +) : TensorRtLlmBackend(engineFolder, executorWorker) {} + + +bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const { + return TensorRtLlmBackend::IsReady(); +} + +uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( + rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed) { + + // This will copy all the items from the initial slice + std::vector tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); + return TensorRtLlmBackend::Submit( + std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed); +} + +size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( + const uint64_t requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback) { + + size_t numTokens = 0; + for (const auto &item: Poll(requestId)) { + GenerationStep step; + if (!item.hasError()) { + SPDLOG_DEBUG("\tStreamTokens -> Decoding token..."); + const auto decoded = item.getResult(); + + const auto token = decoded.outputTokenIds[0][0]; + const auto isFinal = decoded.isFinal; + const auto logProb = decoded.logProbs.value()[0][0]; + + ++numTokens; + + SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal); + step = huggingface::tgi::backends::GenerationStep{ + static_cast(token), logProb, isFinal, false, std::move(std::string()) + }; + SPDLOG_DEBUG("\tStreamTokens -> Post callback"); + } else { + // TODO : Return rest::Result with error + const auto what = item.getErrorMsg(); + SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what); + step = huggingface::tgi::backends::GenerationStep{ + std::numeric_limits::max(), 0.0, true, true, std::move(what) + }; + } + + callback(std::move(ctx), std::move(step)); + } + + return numTokens; +} + +std::unique_ptr +huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) { + // Unconditionally call this to initialize and discover TRTLLM plugins + InitializeBackend(); + + const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end()); + const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end()); + return std::make_unique(std::move(enginePath), std::move(executorPath)); +} diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 00000000..1a804f88 --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,78 @@ +pub use backend::{GenerationContext, TensorRtLlmBackend}; + +mod backend; +pub mod errors; + +#[cxx::bridge(namespace = "huggingface::tgi::backends")] +mod ffi { + + /// Struct used as shared type between rust and C++ to represent the result + /// of a single decoding iteration + pub struct GenerationStep { + token_id: u32, + log_prob: f32, + is_final: bool, + has_error: bool, + error_msg: String, + } + + extern "Rust" { + type GenerationContext; + } + + unsafe extern "C++" { + include!("backends/trtllm/src/ffi.cpp"); + + /// Represent an instance of the underlying TensorRT-LLM backend + type TensorRtLlmBackendImpl; + + /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend + /// + /// # Arguments + /// + /// * `engine_folder`: Path to the folder containing all the TRTLLM engines + /// * `executor_worker`: Path to the TRTLLM executor worker + /// + /// returns: + /// + /// # Examples + /// + /// ``` + /// + /// ``` + #[rust_name = "create_tensorrt_llm_backend"] + fn CreateTensorRtLlmBackend( + engine_folder: &str, + executor_worker: &str, + ) -> UniquePtr; + + // #[rust_name = "is_ready"] + // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + + #[rust_name = "num_responses_ready"] + fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; + + #[rust_name = "submit"] + fn Submit( + self: Pin<&mut TensorRtLlmBackendImpl>, + tokens: &[u32], + top_k: i32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) -> u64; + + #[rust_name = "stream_tokens"] + unsafe fn StreamTokens( + self: Pin<&mut TensorRtLlmBackendImpl>, + request_id: u64, + ctx: *mut GenerationContext, + cb: unsafe fn(*mut GenerationContext, GenerationStep), + ) -> usize; + + // #[rust_name = "shutdown"] + // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); + } +} diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs new file mode 100644 index 00000000..6d6ee146 --- /dev/null +++ b/backends/trtllm/src/main.rs @@ -0,0 +1,166 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use clap::Parser; +use tokenizers::{FromPretrainedParameters, Tokenizer}; + +use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; +use text_generation_backends_trtllm::TensorRtLlmBackend; +use text_generation_router::server; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(long, env, required = true)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env)] + model_id: String, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env)] + auth_token: Option, + #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] + executor_worker: PathBuf, +} + +#[tokio::main] +async fn main() -> Result<(), TensorRtLlmBackendError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, + hostname, + port, + tokenizer_name, + tokenizer_config_path, + revision, + model_id, + validation_workers, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + messages_api_enabled, + max_client_batch_size, + auth_token, + executor_worker, + } = args; + + // Launch Tokio runtime + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if !executor_worker.exists() { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!( + "`executor_work` specified path doesn't exists: {}", + executor_worker.display() + ))); + } + + // Run server + let tokenizer = Tokenizer::from_pretrained( + tokenizer_name.clone(), + Some(FromPretrainedParameters { + revision: revision.clone().unwrap_or(String::from("main")), + user_agent: HashMap::new(), + auth_token, + }), + ) + .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; + + let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?; + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + None, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + false, + None, + None, + messages_api_enabled, + true, + max_client_batch_size, + ) + .await?; + Ok(()) +} diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp new file mode 100644 index 00000000..8520065a --- /dev/null +++ b/backends/trtllm/tests/infer_test.cpp @@ -0,0 +1,14 @@ +// +// Created by mfuntowicz on 7/2/24. +// +#include +#include +#include "../include/backend.h" + +TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") { + const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/"); + const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker"); + + spdlog::info("Loading config from: {}", absolute(engines).string()); + huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor); +} diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml new file mode 100644 index 00000000..5d9a140b --- /dev/null +++ b/backends/v3/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "text-generation-router-v3" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +thiserror = "1.0.48" +tokenizers = { workspace = true} +tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +minijinja = { version = "2.0.2" } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] diff --git a/backends/v3/build.rs b/backends/v3/build.rs new file mode 100644 index 00000000..6d702d14 --- /dev/null +++ b/backends/v3/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs new file mode 100644 index 00000000..49e2bc8f --- /dev/null +++ b/backends/v3/src/backend.rs @@ -0,0 +1,501 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; +/// Batching and inference logic +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV3 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV3 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + ) -> Self { + let queue = Queue::new( + requires_padding, + 16, + window_size, + speculate, + max_batch_total_tokens, + ); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV3 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + err + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v3_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/router/src/infer/v3/block_allocator.rs b/backends/v3/src/block_allocator.rs similarity index 100% rename from router/src/infer/v3/block_allocator.rs rename to backends/v3/src/block_allocator.rs diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs new file mode 100644 index 00000000..c407687b --- /dev/null +++ b/backends/v3/src/client/grpc_client.rs @@ -0,0 +1,284 @@ +/// Single shard Client +use crate::client::{pb, Chunk}; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + adapter_id: None, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: max_input_length, + max_blocks: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs new file mode 100644 index 00000000..755431f4 --- /dev/null +++ b/backends/v3/src/client/mod.rs @@ -0,0 +1,76 @@ +//! Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. + /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +// Small convenience re-wrapping of `Chunk`. +impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs new file mode 100644 index 00000000..afb13cdc --- /dev/null +++ b/backends/v3/src/client/sharded_client.rs @@ -0,0 +1,260 @@ +use crate::client::{ClientError, Result}; +/// Multi shard Client +use crate::client::{Health, ShardInfo}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use crate::client::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. + async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), + adapter_id: None, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + max_blocks: 1, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs new file mode 100644 index 00000000..a6f89169 --- /dev/null +++ b/backends/v3/src/lib.rs @@ -0,0 +1,142 @@ +mod backend; +mod block_allocator; +mod client; +mod queue; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV3; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV3, BackendInfo), V3Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V3Error::Connection)?; + + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V3Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + }; + + let backend = BackendV3::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + ); + + tracing::info!("Using backend V3"); + + Ok((backend, backend_info)) +} + +#[derive(Debug, Error)] +pub enum V3Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs new file mode 100644 index 00000000..ef10514f --- /dev/null +++ b/backends/v3/src/main.rs @@ -0,0 +1,208 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::server; +use text_generation_router_v3::{connect_backend, V3Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env, default_value_t)] + disable_usage_stats: bool, + #[clap(long, env, default_value_t)] + disable_crash_reports: bool, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + disable_usage_stats, + disable_crash_reports, + max_client_batch_size, + } = args; + + if let Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V3Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/router/src/infer/v3/queue.rs b/backends/v3/src/queue.rs similarity index 95% rename from router/src/infer/v3/queue.rs rename to backends/v3/src/queue.rs index 894d9cab..9427bd60 100644 --- a/router/src/infer/v3/queue.rs +++ b/backends/v3/src/queue.rs @@ -1,17 +1,17 @@ -use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; -use crate::infer::InferError; -use crate::infer::InferStreamResponse; -use crate::validation::{ - ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +use crate::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::client; +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::{max, min}; use std::collections::VecDeque; -use text_generation_client::v3::{ - Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, + ValidStoppingParameters, }; -use text_generation_client::ChunksToString; -use text_generation_client::Input; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; @@ -337,8 +337,22 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - input_chunks: Some(Input { - chunks: entry.request.inputs.clone(), + input_chunks: Some(client::Input { + chunks: entry + .request + .inputs + .clone() + .into_iter() + .map(|c| client::InputChunk { + chunk: Some(match c { + Chunk::Text(text) => client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { + data: image.data, + mimetype: image.mimetype, + }), + }), + }) + .collect(), }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 756460e0..f82659c9 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -21,7 +21,7 @@ float-ord = "0.3.2" serde = {version = "1.0.188", features = ["derive"]} serde_json = "1.0" tabled = "0.14.0" -text-generation-client = { path = "../router/client" } +text-generation-client = { path = "../backends/client" } thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } diff --git a/docs/openapi.json b/docs/openapi.json index 873044a7..4d3a2398 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1580,16 +1580,11 @@ "type": "object", "required": [ "model_id", - "model_dtype", - "model_device_type", "max_concurrent_requests", "max_best_of", "max_stop_sequences", "max_input_tokens", "max_total_tokens", - "waiting_served_ratio", - "max_batch_total_tokens", - "max_waiting_tokens", "validation_workers", "max_client_batch_size", "router", @@ -1601,18 +1596,6 @@ "example": "null", "nullable": true }, - "max_batch_size": { - "type": "integer", - "example": "null", - "nullable": true, - "minimum": 0 - }, - "max_batch_total_tokens": { - "type": "integer", - "format": "int32", - "example": "32000", - "minimum": 0 - }, "max_best_of": { "type": "integer", "example": "2", @@ -1644,19 +1627,6 @@ "example": "2048", "minimum": 0 }, - "max_waiting_tokens": { - "type": "integer", - "example": "20", - "minimum": 0 - }, - "model_device_type": { - "type": "string", - "example": "cuda" - }, - "model_dtype": { - "type": "string", - "example": "torch.float16" - }, "model_id": { "type": "string", "description": "Model info", @@ -1690,11 +1660,6 @@ "version": { "type": "string", "example": "0.5.0" - }, - "waiting_served_ratio": { - "type": "number", - "format": "float", - "example": "1.2" } } }, diff --git a/router/Cargo.toml b/router/Cargo.toml index 0fc700a0..1be74546 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -7,25 +7,18 @@ edition.workspace = true authors.workspace = true homepage.workspace = true -[lib] -path = "src/lib.rs" - -[[bin]] -name = "text-generation-router" -path = "src/main.rs" - [dependencies] +async-trait = "0.1.74" async-stream = "0.3.5" axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" -text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.23.0" -metrics-exporter-prometheus = { version = "0.15.1", features = [] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.13.0" @@ -55,6 +48,7 @@ base64 = { workspace = true } sysinfo = "0.30.13" uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } csv = "1.3.0" +ureq = "=2.9" [build-dependencies] diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/chat_template.rs similarity index 67% rename from router/src/infer/v3/scheduler.rs rename to router/src/infer/chat_template.rs index 26cd9584..24a00352 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/chat_template.rs @@ -1,528 +1,85 @@ -/// Batching and inference logic -use crate::infer::v3::queue::{Entry, Queue}; -use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, +use crate::infer::InferError; +use crate::{ + ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, }; -use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::ClientError; -use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; -use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{info_span, instrument, Instrument, Span}; +use minijinja::{Environment, ErrorKind, Template}; +use minijinja_contrib::pycompat; -pub(crate) struct SchedulerV3 { - /// Request queue - queue: Queue, - /// Notify batcher on queue appends - batching_task_notifier: Arc, +/// Raise a exception (custom function) used in the chat templates +pub(crate) fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) } -impl SchedulerV3 { - #[allow(clippy::too_many_arguments)] +#[derive(Clone)] +pub(crate) struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { pub(crate) fn new( - client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, - generation_health: Arc, + template: String, + bos_token: Option, + eos_token: Option, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") - } else { - false - }; - let block_size = if flashdecoding { 256 } else { 16 }; - let queue = Queue::new( - requires_padding, - block_size, - window_size, - speculate, - max_batch_total_tokens, - ); - let batching_task_notifier = Arc::new(Notify::new()); + let mut env = Box::new(Environment::new()); + // enable things like .strip() or .capitalize() + env.set_unknown_method_callback(pycompat::unknown_method_callback); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - queue.clone(), - batching_task_notifier.clone(), - generation_health, - )); + // check if contains the tools variable within the template + let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); Self { - queue, - batching_task_notifier, + template, + bos_token: bos_token.map(|token| token.as_str().to_string()), + eos_token: eos_token.map(|token| token.as_str().to_string()), + use_default_tool_template, } } -} -impl Scheduler for SchedulerV3 { - #[instrument(skip_all)] - fn schedule( + pub(crate) fn apply( &self, - request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result { - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; - - // Append the request to the queue - self.queue.append(Entry { - request, - response_tx, - span: Span::current(), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - block_allocation: None, - }); - - // Notify the background task that we have a new entry in the queue that needs - // to be batched - self.batching_task_notifier.notify_one(); - - // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[allow(clippy::too_many_arguments)] -pub(crate) async fn batching_task( - mut client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - queue: Queue, - notifier: Arc, - generation_health: Arc, -) { - // Infinite loop - loop { - // Wait for a notification from the Infer struct - notifier.notified().await; - - // Get the next batch from the queue - // This batch might be smaller than the maximum batch size if there are not enough requests - // waiting in the queue - while let Some((mut entries, batch, span)) = queue - .next_batch( - None, - max_batch_size, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) - .await - { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) - .await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let batch_max_tokens = batch.max_tokens; - let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None - } else { - // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; - - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); - - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) - .await - { - // Tracking metrics - if min_size.is_some() { - metrics::counter!("tgi_batch_concat", "reason" => "backpressure") - .increment(1); - } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") - .increment(1); - } - - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), }); - - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } } - - // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); - let next_batch_span = - info_span!(parent: None, "batch", batch_size = next_batch_size); - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); - - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) - .instrument(next_batch_span) - .await; - waiting_tokens += 1; - } - metrics::gauge!("tgi_batch_current_size").set(0.0); - metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); - } - } -} - -#[instrument(skip_all)] -async fn prefill( - client: &mut ShardedClient, - batch: Batch, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_id = batch.id; - metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); - - match client.prefill(batch).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - // Update health - generation_health.store(false, Ordering::SeqCst); - let _ = client.clear_cache(Some(batch_id)).await; - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); - None - } - } -} - -#[instrument(skip_all)] -async fn decode( - client: &mut ShardedClient, - batches: Vec, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); - - match client.decode(batches).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") - .record(concat_duration.as_secs_f64()); - } - metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - generation_health.store(false, Ordering::SeqCst); - for id in batch_ids { - let _ = client.clear_cache(Some(id)).await; - } - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); - None - } - } -} - -/// Filter a `batch` and remove all requests not present in `entries` -#[instrument(skip_all)] -async fn filter_batch( - client: &mut ShardedClient, - next_batch: Option, - entries: &IntMap, -) -> Option { - let mut batch = next_batch?; - - // No need to filter - if batch.size as usize == entries.len() { - return Some(batch); - } - - let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { - // All requests have been filtered out - // Next batch is now empty - // Clear it from the Python shards cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.clear_cache(Some(id)).await.unwrap(); - None - } else { - // Filter Python shard cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() - } -} - -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries -#[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { - let id = generation.request_id; - // Get entry - // We can `expect` here as the request id should always be in the entries - let entry = entries - .get(&id) - .expect("ID not found in entries. This is a bug."); - - // Create and enter a span to link this function back to the entry - let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - err - }).unwrap_or(true); - if stopped { - entries.remove(&id).expect("ID not found in entries. This is a bug."); - } - }); -} - -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, - entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - return Ok(true); - } - - let mut stopped = false; - - if let Some(prefill_tokens) = generation.prefill_tokens { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - let prefill_tokens = prefill_tokens - .ids - .into_iter() - .zip(prefill_tokens.logprobs) - .zip(prefill_tokens.texts) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; - } - - // Create last Token - let tokens_ = generation.tokens.expect("Non empty tokens in generation"); - let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); - let mut iterator = tokens_ - .ids - .into_iter() - .zip(tokens_.logprobs) - .zip(tokens_.texts) - .zip(tokens_.is_special) - .enumerate() - .peekable(); - while let Some((i, (((id, logprob), text), special))) = iterator.next() { - let token = Token { - id, - text, - logprob, - special, - }; - let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { - top_tokens_ - .ids - .iter() - .zip(top_tokens_.logprobs.iter()) - .zip(top_tokens_.texts.iter()) - .zip(top_tokens_.is_special.iter()) - .map(|(((&id, &logprob), text), &special)| Token { - id, - text: text.to_string(), - logprob, - special, - }) - .collect() - } else { - vec![] - }; - match (&generation.generated_text, iterator.peek()) { - (Some(generated_text), None) => { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - top_tokens, - generated_text: GeneratedText::from(generated_text.clone()), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } - _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; } } - } - Ok(stopped) -} + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); -/// Send errors to Infer for all `entries` -#[instrument(skip_all)] -fn send_errors(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); - let err = InferError::GenerationError(error.to_string()); - metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); - tracing::error!("{err}"); - - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Err(err)) - .unwrap_or(()); - }); -} - -impl From for GeneratedText { - fn from(value: text_generation_client::v3::GeneratedText) -> Self { - let v3_finish_reason = - text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); - let finish_reason = match v3_finish_reason { - text_generation_client::v3::FinishReason::Length => FinishReason::Length, - text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, - }; - - Self { - text: value.text, - generated_tokens: value.generated_tokens, - finish_reason, - seed: value.seed, - } + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) } } // tests #[cfg(test)] mod tests { - use crate::infer::raise_exception; + use crate::infer::chat_template::raise_exception; use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; diff --git a/router/src/infer/health.rs b/router/src/infer/health.rs deleted file mode 100644 index 4320c1a4..00000000 --- a/router/src/infer/health.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use text_generation_client::Health; - -#[derive(Clone)] -pub(crate) struct HealthCheck { - client: Arc, - generation_health: Arc, -} - -impl HealthCheck { - pub(crate) fn new( - client: Arc, - generation_health: Arc, - ) -> Self { - Self { - client, - generation_health, - } - } - - pub(crate) async fn check(&mut self) -> bool { - let value = if self.generation_health.load(Ordering::SeqCst) { - // Generation is healthy, we only check that the shards can allocate on device - self.client.device_health().await - } else { - self.client.model_health().await - } - .is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value - } -} diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index db9070d4..534a2647 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -1,23 +1,18 @@ -mod health; -pub(crate) mod v2; -pub(crate) mod v3; - -pub(crate) use health::HealthCheck; +// pub(crate) mod v2; +mod chat_template; +pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; +use crate::GrammarType; use crate::{ - ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, -}; -use crate::{ - FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, + ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, + Message, PrefillToken, Token, }; +use async_trait::async_trait; +use chat_template::ChatTemplate; use futures::future::try_join_all; -use minijinja::{Environment, ErrorKind, Template}; -use minijinja_contrib::pycompat; - -use serde_json::{json, Map, Value}; -use std::collections::HashMap; +use minijinja::ErrorKind; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; @@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; -pub(crate) trait Scheduler { +#[async_trait] +pub trait Backend { fn schedule( &self, request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result; + ) -> Result>, InferError>; + + async fn health(&self, current_health: bool) -> bool; } /// Inference struct @@ -39,18 +36,20 @@ pub(crate) trait Scheduler { pub struct Infer { /// Validation validation: Validation, - /// Request scheduler - scheduler: Arc, + /// Request backend + backend: Arc, /// Chat template chat_template: Option, /// Inference limit limit_concurrent_requests: Arc, + /// Backend health + backend_health: Arc, } impl Infer { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - scheduler: Arc, + backend: impl Backend + Send + Sync + 'static, validation: Validation, max_concurrent_requests: usize, tokenizer_config: HubTokenizerConfig, @@ -71,18 +70,22 @@ impl Infer { // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + // Backend health + let backend_health = Arc::new(AtomicBool::new(false)); + Self { validation, - scheduler, + backend: Arc::new(backend), chat_template, limit_concurrent_requests: semaphore, + backend_health, } } /// Add a new request to the queue and return a stream of InferStreamResponse #[instrument(skip_all)] - pub(crate) async fn generate_stream( - &self, + pub(crate) async fn generate_stream<'a>( + &'a self, request: GenerateRequest, ) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore @@ -103,7 +106,10 @@ impl Infer { err })?; - self.scheduler.schedule(valid_request, permit) + let input_length = valid_request.input_length; + let generation_stream = self.backend.schedule(valid_request)?; + + Ok((permit, input_length, generation_stream)) } /// Tokenizer the input @@ -155,7 +161,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + let (_permit, _input_length, stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -165,6 +171,8 @@ impl Infer { let mut result_start = None; let mut result_queued = None; + let mut stream = Box::pin(stream); + // Iterate on stream while let Some(response) = stream.next().await { match response? { @@ -256,207 +264,15 @@ impl Infer { let best_response = infer_responses.remove(max_index); Ok((best_response, infer_responses)) } -} -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl ChatTemplate { - fn new( - template: String, - bos_token: Option, - eos_token: Option, - ) -> Self { - let mut env = Box::new(Environment::new()); - // enable things like .strip() or .capitalize() - env.set_unknown_method_callback(pycompat::unknown_method_callback); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. - let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token: bos_token.map(|token| token.as_str().to_string()), - eos_token: eos_token.map(|token| token.as_str().to_string()), - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - }); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) - } -} - -pub struct ToolGrammar {} - -impl ToolGrammar { - // find a tool by name - fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { - tools - .iter() - .find(|tool| tool.function.name == name) - .cloned() - .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) - } - - pub fn apply( - tools: Option>, - tool_choice: ToolChoice, - ) -> Result, InferError> { - // if no tools are provided, we return None - let tools = match tools { - Some(tools) if !tools.is_empty() => tools, - _ => return Ok(None), - }; - - let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - - // if tools are provided and no tool_choice we default to the OneOf - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![Self::find_tool_by_name(&tools, &name)?] - } - ToolType::Function { function } => { - vec![Self::find_tool_by_name(&tools, &function.name)?] - } - ToolType::OneOf => tools, - ToolType::NoTool => return Ok(None), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - Ok(Some(tools)) + #[instrument(skip(self))] + pub(crate) async fn health(&self) -> bool { + let health = self + .backend + .health(self.backend_health.load(Ordering::SeqCst)) + .await; + self.backend_health.store(health, Ordering::SeqCst); + health } } @@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = ( ); #[derive(Debug)] -pub(crate) struct GeneratedText { - pub(crate) text: String, - pub(crate) generated_tokens: u32, - pub(crate) finish_reason: FinishReason, - pub(crate) seed: Option, +pub struct GeneratedText { + pub text: String, + pub generated_tokens: u32, + pub finish_reason: FinishReason, + pub seed: Option, } #[derive(Debug)] -pub(crate) enum InferStreamResponse { +pub enum InferStreamResponse { // Optional first message Prefill(Vec), // Intermediate messages diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs new file mode 100644 index 00000000..05027f30 --- /dev/null +++ b/router/src/infer/tool_grammar.rs @@ -0,0 +1,135 @@ +use crate::infer::InferError; +use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; + +pub(crate) struct ToolGrammar {} + +impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + + pub fn apply( + tools: Option>, + tool_choice: ToolChoice, + ) -> Result, InferError> { + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; + + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + + // if tools are provided and no tool_choice we default to the OneOf + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] + } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) + } +} diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs index 8b4f6bab..6a91a433 100644 --- a/router/src/infer/v2/mod.rs +++ b/router/src/infer/v2/mod.rs @@ -1,4 +1,4 @@ mod queue; mod scheduler; -pub(crate) use scheduler::SchedulerV2; +pub(crate) use scheduler::BackendV2; diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 97379bc5..3d6c36cf 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,7 +1,7 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, + Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; @@ -18,14 +18,14 @@ use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Instrument, Span}; -pub(crate) struct SchedulerV2 { +pub(crate) struct BackendV2 { /// Request queue queue: Queue, /// Notify batcher on queue appends batching_task_notifier: Arc, } -impl SchedulerV2 { +impl BackendV2 { #[allow(clippy::too_many_arguments)] pub(crate) fn new( client: ShardedClient, @@ -69,7 +69,7 @@ impl SchedulerV2 { } } -impl Scheduler for SchedulerV2 { +impl Backend for BackendV2 { #[instrument(skip_all)] fn schedule( &self, diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs deleted file mode 100644 index f9effab8..00000000 --- a/router/src/infer/v3/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod block_allocator; -mod queue; -mod scheduler; - -pub(crate) use scheduler::SchedulerV3; diff --git a/router/src/lib.rs b/router/src/lib.rs index b6e0d09d..14bb8270 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,11 +1,12 @@ /// Text Generation Inference Webserver pub mod config; -mod infer; +pub mod infer; pub mod server; -mod validation; +pub mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod logging; pub mod usage_stats; @@ -148,12 +149,13 @@ pub struct Info { pub model_id: String, #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] pub model_sha: Option, - #[schema(example = "torch.float16")] - pub model_dtype: String, - #[schema(example = "cuda")] - pub model_device_type: String, + // #[schema(example = "torch.float16")] + // pub model_dtype: String, + // #[schema(example = "cuda")] + // pub model_device_type: String, #[schema(nullable = true, example = "text-generation")] pub model_pipeline_tag: Option, + /// Router Parameters #[schema(example = "128")] pub max_concurrent_requests: usize, @@ -165,18 +167,11 @@ pub struct Info { pub max_input_tokens: usize, #[schema(example = "2048")] pub max_total_tokens: usize, - #[schema(example = "1.2")] - pub waiting_served_ratio: f32, - #[schema(example = "32000")] - pub max_batch_total_tokens: u32, - #[schema(example = "20")] - pub max_waiting_tokens: usize, - #[schema(nullable = true, example = "null")] - pub max_batch_size: Option, #[schema(example = "2")] pub validation_workers: usize, #[schema(example = "32")] pub max_client_batch_size: usize, + /// Router Info #[schema(example = "text-generation-router")] pub router: &'static str, @@ -1068,23 +1063,23 @@ impl From for GenerateRequest { #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, } #[derive(Debug, Serialize, ToSchema, Clone)] pub struct Token { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, #[schema(example = "false")] - special: bool, + pub special: bool, } #[derive(Debug, Serialize, ToSchema)] @@ -1102,7 +1097,7 @@ pub struct SimpleToken { #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all(serialize = "snake_case"))] #[schema(example = "Length")] -pub(crate) enum FinishReason { +pub enum FinishReason { #[schema(rename = "length")] Length, #[serde(rename = "eos_token")] diff --git a/router/src/logging.rs b/router/src/logging.rs new file mode 100644 index 00000000..5a98ef57 --- /dev/null +++ b/router/src/logging.rs @@ -0,0 +1,81 @@ +use opentelemetry::sdk::propagation::TraceContextPropagator; +use opentelemetry::sdk::trace; +use opentelemetry::sdk::trace::Sampler; +use opentelemetry::sdk::Resource; +use opentelemetry::{global, KeyValue}; +use opentelemetry_otlp::WithExportConfig; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; + +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - otlp_service_name service name to appear in APM +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) +pub fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_ansi(ansi) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer.json().flatten_event(true).boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config( + trace::config() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + otlp_service_name, + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + init_tracing_opentelemetry::init_propagator().unwrap(); + }; + } + + // Filter events with LOG_LEVEL + let varname = "LOG_LEVEL"; + let env_filter = if let Ok(log_level) = std::env::var(varname) { + // Override to avoid simple logs to be spammed with tokio level informations + let log_level = match &log_level[..] { + "warn" => "text_generation_launcher=warn,text_generation_router=warn", + "info" => "text_generation_launcher=info,text_generation_router=info", + "debug" => "text_generation_launcher=debug,text_generation_router=debug", + log_level => log_level, + }; + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .parse_lossy(log_level) + } else { + EnvFilter::new("info") + }; + + tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} diff --git a/router/src/main.rs b/router/src/main.rs.back similarity index 100% rename from router/src/main.rs rename to router/src/main.rs.back diff --git a/router/src/server.rs b/router/src/server.rs index 0fd5aade..ccbd1535 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,14 +1,13 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::v2::SchedulerV2; -use crate::infer::v3::SchedulerV3; -use crate::infer::{HealthCheck, Scheduler}; -use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; +use crate::infer::tool_grammar::ToolGrammar; +use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; +use crate::usage_stats; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, @@ -27,7 +26,7 @@ use crate::{ use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; -use axum::http::{HeaderMap, Method, StatusCode}; +use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; @@ -37,15 +36,18 @@ use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; -use text_generation_client::{v2, v3, ClientError, ShardInfo}; +use std::fs::File; +use std::io::BufReader; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; use thiserror::Error; +use tokenizers::processors::template::TemplateProcessing; use tokenizers::Tokenizer; use tokio::select; use tokio::signal; @@ -124,12 +126,10 @@ responses( example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})), ) )] -#[instrument(skip(health))] +#[instrument(skip(infer))] /// Health check method -async fn health( - mut health: Extension, -) -> Result<(), (StatusCode, Json)> { - match health.check().await { +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { + match infer.health().await { true => Ok(()), false => Err(( StatusCode::SERVICE_UNAVAILABLE, @@ -430,8 +430,9 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, mut response_stream)) => { + Ok((_permit, _input_length, response_stream)) => { let mut index = 0; + let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream while let Some(response) = response_stream.next().await { index += 1; @@ -1396,262 +1397,456 @@ async fn metrics(prom_handle: Extension) -> String { #[derive(Clone, Debug)] pub(crate) struct ComputeType(String); +// OpenAPI documentation +#[derive(OpenApi)] +#[openapi( +paths( +health, +get_model_info, +compat_generate, +generate, +generate_stream, +chat_completions, +completions, +tokenize, +metrics, +), +components( +schemas( +Info, +CompatGenerateRequest, +GenerateRequest, +GrammarType, +ChatRequest, +Message, +MessageContent, +MessageChunk, +Url, +FunctionName, +OutputMessage, +TextMessage, +ToolCallMessage, +ToolCallDelta, +ChatCompletionComplete, +ChatCompletionChoice, +ChatCompletionDelta, +ChatCompletionChunk, +ChatCompletionLogprob, +ChatCompletionLogprobs, +ChatCompletionTopLogprob, +ChatCompletion, +CompletionRequest, +CompletionComplete, +Chunk, +Completion, +CompletionFinal, +Prompt, +GenerateParameters, +PrefillToken, +Token, +GenerateResponse, +TokenizeResponse, +SimpleToken, +BestOfSequence, +Details, +FinishReason, +StreamResponse, +StreamDetails, +ErrorResponse, +GrammarType, +Usage, +DeltaToolCall, +ToolType, +Tool, +ToolCall, +Function, +FunctionDefinition, +ToolChoice, +) +), +tags( +(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") +), +info( +title = "Text Generation Inference", +license( +name = "Apache 2.0", +url = "https://www.apache.org/licenses/LICENSE-2.0" +) +) +)] +pub struct ApiDoc; + +pub fn schema() -> ApiDoc { + ApiDoc +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( - master_shard_uds_path: String, - model_info: HubModelInfo, - compat_return_full_text: bool, + backend: impl Backend + Send + Sync + 'static, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, - tokenizer: Option, - config: Option, validation_workers: usize, - addr: SocketAddr, - allow_origin: Option, api_key: Option, + tokenizer_name: String, + tokenizer_config_path: Option, + revision: Option, + hostname: String, + port: u16, + cors_allow_origin: Option>, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, - tokenizer_config: HubTokenizerConfig, - preprocessor_config: Option, - processor_config: HubProcessorConfig, messages_api_enabled: bool, - grammar_support: bool, + disable_grammar_support: bool, max_client_batch_size: usize, - print_schema_command: bool, + disable_usage_stats: bool, + disable_crash_reports: bool, ) -> Result<(), WebServerError> { - // OpenAPI documentation - #[derive(OpenApi)] - #[openapi( - paths( - health, - get_model_info, - compat_generate, - generate, - generate_stream, - chat_completions, - completions, - tokenize, - metrics, - ), - components( - schemas( - Info, - CompatGenerateRequest, - GenerateRequest, - GrammarType, - ChatRequest, - Message, - MessageContent, - MessageChunk, - Url, - FunctionName, - OutputMessage, - TextMessage, - ToolCallMessage, - ToolCallDelta, - ChatCompletionComplete, - ChatCompletionChoice, - ChatCompletionDelta, - ChatCompletionChunk, - ChatCompletionLogprob, - ChatCompletionLogprobs, - ChatCompletionTopLogprob, - ChatCompletion, - CompletionRequest, - CompletionComplete, - Chunk, - Completion, - CompletionFinal, - Prompt, - GenerateParameters, - PrefillToken, - Token, - GenerateResponse, - TokenizeResponse, - SimpleToken, - BestOfSequence, - Details, - FinishReason, - StreamResponse, - StreamDetails, - ErrorResponse, - GrammarType, - Usage, - DeltaToolCall, - ToolType, - Tool, - ToolCall, - Function, - FunctionDefinition, - ToolChoice, - ) - ), - tags( - (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") - ), - info( - title = "Text Generation Inference", - license( - name = "Apache 2.0", - url = "https://www.apache.org/licenses/LICENSE-2.0" - ) - ) - )] - struct ApiDoc; + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); - // Create state - if print_schema_command { - let api_doc = ApiDoc::openapi(); - let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); - println!("{}", api_doc); - std::process::exit(0); + // Parse Huggingface hub token + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); + + // Tokenizer instance + // This will only be used to validate payloads + let local_path = Path::new(&tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, } - - // Open connection, get model info and warmup - let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( - Arc, - HealthCheck, - ShardInfo, - u32, - ) = { - // Helper function to check both v2 and v3 - let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { - match max_supported_batch_total_tokens { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( - 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), - ); - tracing::warn!("Model does not support automatic max batch total tokens"); - Ok(max_batch_total_tokens) + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." - ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(WebServerError::NotEnoughMemory(max_total_tokens)); - } - - Ok(max_supported_batch_total_tokens) - } - } - }; - - let generation_health = Arc::new(AtomicBool::new(false)); - - match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await { - Ok(mut sharded_client) => { - // server is running on v3 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV3::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V3"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) - } - Err(_) => { - let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(WebServerError::Connection)?; - - // server is running on v2 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV2::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V2"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) } } + } else { + Type::None }; - tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + // Load tokenizer and model info + let ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + model_info, + ) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("preprocessor_config.json")), + Some(local_path.join("processor_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + + let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + repo.get("preprocessor_config.json"), + repo.get("processor_config.json"), + None, + ) + } + }; + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + let tokenizer: Option = tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer { + if let Some(class) = &tokenizer_config.tokenizer_class { + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ + if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { + tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + tokenizer.with_post_processor(post_processor); + } + } + } + } + tokenizer + }); + + let config: Option = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + + let preprocessor_config: Option = + preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); + + tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } + + // Only send usage stats when TGI is run in container and the function returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + + let user_agent = if !disable_usage_stats && is_container { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_config.tokenizer_class.clone(), + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, + revision.clone(), + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } else { + None + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + true + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + let result = start( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + config, + (tokenizer, tokenizer_config), + (preprocessor_config, processor_config), + hostname, + port, + ngrok, + _ngrok_authtoken, + _ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + model_info, + compat_return_full_text, + allow_origin, + ) + .await; + + if let Some(ua) = user_agent { + match result { + Ok(_) => { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + stop_event.send().await; + Ok(()) + } + Err(e) => { + if !disable_crash_reports { + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some(e.to_string()), + ); + error_event.send().await; + } else { + let unknow_error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some("unknow_error".to_string()), + ); + unknow_error_event.send().await; + } + Err(e) + } + } + } else { + result + } +} + +#[allow(clippy::too_many_arguments)] +async fn start( + backend: impl Backend + Send + Sync + 'static, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + validation_workers: usize, + api_key: Option, + config: Option, + (tokenizer, tokenizer_config): (Option, HubTokenizerConfig), + (preprocessor_config, processor_config): (Option, HubProcessorConfig), + hostname: String, + port: u16, + ngrok: bool, + _ngrok_authtoken: Option, + _ngrok_edge: Option, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + model_info: HubModelInfo, + compat_return_full_text: bool, + allow_origin: Option, +) -> Result<(), WebServerError> { + // Determine the server port based on the feature and environment variable. + let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + + // Create state let validation = Validation::new( validation_workers, tokenizer, @@ -1662,11 +1857,11 @@ pub async fn run( max_top_n_tokens, max_input_tokens, max_total_tokens, - grammar_support, + disable_grammar_support, ); let infer = Infer::new( - scheduler, + backend, validation, max_concurrent_requests, tokenizer_config, @@ -1703,8 +1898,8 @@ pub async fn run( let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); let batch_size_buckets: Vec = (0..1024).map(|x| (x + 1) as f64).collect(); // Speculated tokens buckets - let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); - let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); + // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); + // let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); // Prometheus handler let builder = PrometheusBuilder::new() @@ -1717,9 +1912,9 @@ pub async fn run( .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) .unwrap() .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) - .unwrap() - .set_buckets_for_metric(skipped_matcher, &skipped_buckets) .unwrap(); + // .set_buckets_for_metric(skipped_matcher, &skipped_buckets) + // .unwrap(); let prom_handle = builder .install_recorder() .expect("failed to install metrics recorder"); @@ -1735,18 +1930,18 @@ pub async fn run( let info = Info { model_id: model_info.model_id, model_sha: model_info.sha, - model_dtype: shard_info.dtype, - model_device_type: shard_info.device_type, + // model_dtype: shard_info.dtype, + // model_device_type: shard_info.device_type, model_pipeline_tag: model_info.pipeline_tag, max_concurrent_requests, max_best_of, max_stop_sequences, max_input_tokens, max_total_tokens, - waiting_served_ratio, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, + // waiting_served_ratio, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, validation_workers, max_client_batch_size, router: env!("CARGO_PKG_NAME"), @@ -1907,7 +2102,6 @@ pub async fn run( // add layers after routes app = app .layer(Extension(info)) - .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(compute_type)) @@ -1945,6 +2139,68 @@ pub async fn run( Ok(()) } +/// get model info from the Huggingface Hub +pub async fn get_hub_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; + + if response.status().is_success() { + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None + } +} + +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) + .map_err(|e| { + tracing::warn!("Unable to parse tokenizer config: {}", e); + e + }) + .ok()?; + + Some(tokenizer_config) +} + /// Shutdown signal handler async fn shutdown_signal() { let ctrl_c = async { @@ -2008,16 +2264,77 @@ impl From for Event { #[derive(Debug, Error)] pub enum WebServerError { - #[error("Unable to connect to the Python model shards: {0}")] - Connection(ClientError), - #[error("Unable to clear the Python model shards cache: {0}")] - Cache(ClientError), - #[error("Unable to get the Python model shards info: {0}")] - Info(ClientError), - #[error("Unable to warmup the Python model shards: {0}")] - Warmup(ClientError), - #[error("Not enough memory to handle `max_total_tokens={0}`")] - NotEnoughMemory(usize), #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), } + +/// Create a post_processor for the LlamaTokenizer +fn create_post_processor( + tokenizer: &Tokenizer, + tokenizer_config: &HubTokenizerConfig, +) -> Result { + let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); + let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); + + let bos_token = tokenizer_config.bos_token.as_ref(); + let eos_token = tokenizer_config.eos_token.as_ref(); + + if add_bos_token && bos_token.is_none() { + panic!("add_bos_token = true but bos_token is None"); + } + + if add_eos_token && eos_token.is_none() { + panic!("add_eos_token = true but eos_token is None"); + } + + let mut single = Vec::new(); + let mut pair = Vec::new(); + let mut special_tokens = Vec::new(); + + if add_bos_token { + if let Some(bos) = bos_token { + let bos_token_id = tokenizer + .token_to_id(bos.as_str()) + .expect("Should have found the bos token id"); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); + } + } + + single.push("$A:0".to_string()); + pair.push("$A:0".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + let eos_token_id = tokenizer + .token_to_id(eos.as_str()) + .expect("Should have found the eos token id"); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); + } + } + + if add_bos_token { + if let Some(bos) = bos_token { + pair.push(format!("{}:1", bos.as_str())); + } + } + + pair.push("$B:1".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + pair.push(format!("{}:1", eos.as_str())); + } + } + + let post_processor = TemplateProcessing::builder() + .try_single(single)? + .try_pair(pair)? + .special_tokens(special_tokens) + .build()?; + + Ok(post_processor) +} diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 8559ae90..fa9f3637 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -78,11 +78,11 @@ pub struct Args { max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, + // waiting_served_ratio: f32, + // max_batch_prefill_tokens: u32, + // max_batch_total_tokens: Option, + // max_waiting_tokens: usize, + // max_batch_size: Option, revision: Option, validation_workers: usize, messages_api_enabled: bool, @@ -103,11 +103,11 @@ impl Args { max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, + // waiting_served_ratio: f32, + // max_batch_prefill_tokens: u32, + // max_batch_total_tokens: Option, + // max_waiting_tokens: usize, + // max_batch_size: Option, revision: Option, validation_workers: usize, messages_api_enabled: bool, @@ -125,11 +125,11 @@ impl Args { max_top_n_tokens, max_input_tokens, max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, revision, validation_workers, messages_api_enabled, diff --git a/router/src/validation.rs b/router/src/validation.rs index 1913a2ce..3d1a4103 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -5,13 +5,12 @@ use crate::{ GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, }; use base64::{engine::general_purpose::STANDARD, Engine}; -use image::{io::Reader as ImageReader, ImageFormat}; +use image::{ImageFormat, ImageReader}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; -use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -96,7 +95,7 @@ impl Validation { &self, inputs: String, truncate: Option, - ) -> Result)>, ValidationError> { + ) -> Result)>, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -122,7 +121,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -181,11 +180,7 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok(( - vec![Chunk::Text(inputs).into()], - input_length, - max_new_tokens, - )) + Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) } } @@ -589,7 +584,7 @@ fn prepare_input( tokenizer: &Tokenizer, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, -) -> Result<(tokenizers::Encoding, Vec), ValidationError> { +) -> Result<(tokenizers::Encoding, Vec), ValidationError> { use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { @@ -601,16 +596,16 @@ fn prepare_input( let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string())); tokenizer_query.push_str(&inputs[start..chunk_start]); } let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; - input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); + input_chunks.push(Chunk::Image(Image { data, mimetype })); tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width)); start = chunk_end; } if start != inputs.len() { - input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); + input_chunks.push(Chunk::Text(inputs[start..].to_string())); tokenizer_query.push_str(&inputs[start..]); } @@ -618,7 +613,7 @@ fn prepare_input( (tokenizer_query, input_chunks) } - _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), + _ => (inputs.clone(), vec![Chunk::Text(inputs)]), }; // Get the number of tokens in the input @@ -631,18 +626,51 @@ fn prepare_input( type TokenizerRequest = ( (String, Option), - oneshot::Sender), ValidationError>>, + oneshot::Sender), ValidationError>>, Span, ); +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Image { + pub data: Vec, + pub mimetype: String, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Chunk { + Text(String), + Image(Image), +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. + fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c { + Chunk::Text(text) => output.push_str(text), + Chunk::Image(Image { data, mimetype }) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + }); + output + } +} + #[derive(Debug, Clone)] -pub(crate) enum ValidGrammar { +pub enum ValidGrammar { Json(String), Regex(String), } #[derive(Debug, Clone)] -pub(crate) struct ValidParameters { +pub struct ValidParameters { /// / exponential scaling output probability distribution pub temperature: f32, /// / restricting to the k highest probability elements @@ -666,7 +694,7 @@ pub(crate) struct ValidParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidStoppingParameters { +pub struct ValidStoppingParameters { /// / Maximum number of generated tokens pub max_new_tokens: u32, /// / Optional stopping sequences @@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidGenerateRequest { - pub inputs: Vec, +pub struct ValidGenerateRequest { + pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, @@ -750,6 +778,8 @@ pub enum ValidationError { InvalidImageContent(String), #[error("Could not fetch image: {0}")] FailedFetchImage(#[from] reqwest::Error), + #[error("{0} modality is not supported")] + UnsupportedModality(&'static str), } #[cfg(test)] diff --git a/update_doc.py b/update_doc.py index bfa7e4e9..428d4452 100644 --- a/update_doc.py +++ b/update_doc.py @@ -167,22 +167,24 @@ def check_openapi(check: bool): else: os.rename(tmp_filename, filename) print("OpenAPI documentation updated.") - errors = subprocess.run( + p = subprocess.run( [ - "swagger-cli", + "redocly", # allow for trailing whitespace since it's not significant # and the precommit hook will remove it - "validate", + "lint", filename, ], capture_output=True, - ).stderr.decode("utf-8") + ) + errors = p.stderr.decode("utf-8") # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 - if not errors.startswith("Swagger schema validation failed."): + print(errors) + if p.returncode != 0: print(errors) raise Exception( - f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + f"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\n {errors}" ) return True From 468e5c687476ebaf47560c572bcc891665f53379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 31 Jul 2024 13:08:41 +0200 Subject: [PATCH 178/340] Handle GPTQ-Marlin loading in `GPTQMarlinWeightLoader` (#2300) The `GPTWeightLoader` was structured like this in pseudocode: if marlin: Set up tensors in a way that GPTQ-Marlin expects else: Set up tensors in a way that ExLlama/GPTQ/AWQ expect However, the GPT-Marlin implementation details should really be in the `marlin` module. So move the former part out to a separate `GPTQMarlinWeightsLoader`. --- .../layers/gptq/__init__.py | 170 --------------- .../layers/marlin/__init__.py | 6 +- .../layers/marlin/gptq.py | 202 +++++++++++++++++- .../utils/quantization.py | 25 ++- 4 files changed, 224 insertions(+), 179 deletions(-) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 4fb151a9..f6616d3e 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -124,50 +124,7 @@ class GPTQWeightsLoader(WeightsLoader): self.sym = sym def get_weights(self, weights: Weights, prefix: str): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_tensor(f"{prefix}.qweight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - scales = weights.get_tensor(f"{prefix}.scales") - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) use_exllama = True if self.bits != 4: @@ -248,11 +205,6 @@ class GPTQWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ -267,36 +219,6 @@ class GPTQWeightsLoader(WeightsLoader): scales = scales.to(dtype=weights.dtype) self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - if not self.sym: - qzeros = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) qzeros = weights.get_packed_sharded( f"{prefix}.qzeros", dim=1, block_sizes=block_sizes @@ -334,11 +256,6 @@ class GPTQWeightsLoader(WeightsLoader): ) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - try: qweight = torch.cat( [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -353,41 +270,6 @@ class GPTQWeightsLoader(WeightsLoader): ) self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - - if not self.sym: - qzeros = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) qzeros = torch.cat( [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 @@ -441,59 +323,7 @@ class GPTQWeightsLoader(WeightsLoader): ) def get_weights_row(self, weights: Weights, prefix: str): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - if self.desc_act or self.groupsize == -1: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.scales") - else: - scales = weights.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=sharded_in_features, - ) use_exllama = True if self.bits != 4: diff --git a/server/text_generation_server/layers/marlin/__init__.py b/server/text_generation_server/layers/marlin/__init__.py index 8a143f96..3ff3ed58 100644 --- a/server/text_generation_server/layers/marlin/__init__.py +++ b/server/text_generation_server/layers/marlin/__init__.py @@ -1,7 +1,6 @@ from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinLinear, - GPTQMarlinWeight, + GPTQMarlinWeightsLoader, can_use_gptq_marlin, repack_gptq_for_marlin, ) @@ -9,8 +8,7 @@ from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader __all__ = [ "GPTQMarlinFP8Linear", - "GPTQMarlinLinear", - "GPTQMarlinWeight", + "GPTQMarlinWeightsLoader", "MarlinWeightsLoader", "can_use_gptq_marlin", "repack_gptq_for_marlin", diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 80674b14..c7663b60 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Union import numpy import torch @@ -13,7 +13,7 @@ from text_generation_server.layers.marlin.util import ( ) from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader try: import marlin_kernels @@ -48,6 +48,204 @@ def can_use_gptq_marlin( ) +class GPTQMarlinWeightsLoader(WeightsLoader): + """ + Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. + """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + if not self.sym: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + if not self.sym: + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + if not self.sym: + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + if not self.sym: + if self.desc_act or self.groupsize == -1: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + def _get_gptq_params(self, weights: Weights): + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) + self.quant_method = "gptq" + + @dataclass class GPTQMarlinWeight(Weight): """ diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index cfb8b1db..ee561acc 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Optional from huggingface_hub import hf_hub_download +from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin from text_generation_server.utils.weights import ( DefaultWeightsLoader, WeightsLoader, @@ -128,14 +129,32 @@ def get_loader( f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." ) - return GPTQWeightsLoader( + if can_use_gptq_marlin( bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, groupsize=quantizer_config.groupsize, quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, - ) + ): + from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader + + return GPTQMarlinWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + else: + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) elif quantize == "bitsandbytes": from text_generation_server.layers.bnb import BNBWeight From c73d1d604f5159264ce2f4b81b7cacbd540c6dbd Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 31 Jul 2024 10:27:15 -0400 Subject: [PATCH 179/340] Pr 2290 ci run (#2329) * MODEL_ID propagation fix * fix: remove global model id --------- Co-authored-by: root --- server/text_generation_server/models/flash_causal_lm.py | 3 +-- server/text_generation_server/models/globals.py | 9 --------- server/text_generation_server/server.py | 3 +-- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 152516e7..36bb2662 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,7 +43,6 @@ from text_generation_server.models.globals import ( BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, - MODEL_ID, ) from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser @@ -1156,7 +1155,7 @@ class FlashCausalLM(Model): tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, - f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", + f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) log_master( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index ac42df30..8d2431db 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -29,15 +29,6 @@ if cuda_graphs is not None: CUDA_GRAPHS = cuda_graphs -# This is overridden at model loading. -MODEL_ID = None - - -def set_model_id(model_id: str): - global MODEL_ID - MODEL_ID = model_id - - # NOTE: eventually we should move this into the router and pass back the # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 22bd759f..b92ab572 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -30,7 +30,7 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id, set_adapter_to_index +from text_generation_server.models.globals import set_adapter_to_index class SignalHandler: @@ -271,7 +271,6 @@ def serve( while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) - set_model_id(model_id) asyncio.run( serve_inner( model_id, From 3c4f816ae34163199d0c4f77e4517da591f770a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Wed, 31 Jul 2024 16:29:07 +0200 Subject: [PATCH 180/340] refactor usage stats (#2339) * refactor usage stats * Update docs/source/usage_statistics.md Co-authored-by: Nicolas Patry * Update router/src/server.rs Co-authored-by: Nicolas Patry * changes based on feedback * run python3 udpate_doc.py * fix pre-commit * Update router/src/server.rs Co-authored-by: Nicolas Patry * delete option around usage stats arg --------- Co-authored-by: Nicolas Patry --- backends/v3/src/main.rs | 14 ++-- docs/source/basic_tutorials/launcher.md | 20 +++--- docs/source/usage_statistics.md | 4 +- launcher/src/main.rs | 47 +++++++++---- router/src/server.rs | 94 ++++++++++++------------- router/src/usage_stats.rs | 23 +++--- 6 files changed, 109 insertions(+), 93 deletions(-) diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index ef10514f..21952e66 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -1,5 +1,5 @@ use clap::{Parser, Subcommand}; -use text_generation_router::server; +use text_generation_router::{server, usage_stats}; use text_generation_router_v3::{connect_backend, V3Error}; use thiserror::Error; @@ -68,10 +68,8 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, - #[clap(long, env, default_value_t)] - disable_usage_stats: bool, - #[clap(long, env, default_value_t)] - disable_crash_reports: bool, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, } #[derive(Debug, Subcommand)] @@ -114,9 +112,8 @@ async fn main() -> Result<(), RouterError> { ngrok_edge, messages_api_enabled, disable_grammar_support, - disable_usage_stats, - disable_crash_reports, max_client_batch_size, + usage_stats, } = args; if let Some(Commands::PrintSchema) = command { @@ -188,8 +185,7 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, - disable_usage_stats, - disable_crash_reports, + usage_stats, ) .await?; Ok(()) diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index ce98876f..01f15648 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -431,20 +431,18 @@ Options: [env: LORA_ADAPTERS=] ``` -## DISABLE_USAGE_STATS +## USAGE_STATS ```shell - --disable-usage-stats - Disable sending of all usage statistics + --usage-stats + Control if anonymous usage stats are collected. Options are "on", "off" and "no-stack" Defaul is on - [env: DISABLE_USAGE_STATS=] + [env: USAGE_STATS=] + [default: on] -``` -## DISABLE_CRASH_REPORTS -```shell - --disable-crash-reports - Disable sending of crash reports, but allow anonymous usage statistics - - [env: DISABLE_CRASH_REPORTS=] + Possible values: + - on: Default option, usage statistics are collected anonymously + - off: Disables all collection of usage statistics + - no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event ``` ## HELP diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md index adf0d70f..a2c406ec 100644 --- a/docs/source/usage_statistics.md +++ b/docs/source/usage_statistics.md @@ -70,4 +70,6 @@ As of release 2.1.2 this is an example of the data collected: ## How to opt-out -You can easily opt out by passing the `--disable-usage-stats` to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports` which disables sending specific crash reports, but allows anonymous usage statistics. +By passing the `--usage-stats` to the text-generation-launcher you can control how much usage statistics are being collected. +`--usage-stats=no-stack` will not emit the stack traces from errors and the error types, but will continue to send start and stop events +`--usage-stats=off` will completely disable everything diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 15233306..dd4a92ec 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -170,6 +170,33 @@ impl std::fmt::Display for RopeScaling { } } +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum UsageStatsLevel { + /// Default option, usage statistics are collected anonymously + On, + /// Disables all collection of usage statistics + Off, + /// Doesn't send the error stack trace or error type, but allows sending a crash event + NoStack, +} + +impl std::fmt::Display for UsageStatsLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + UsageStatsLevel::On => { + write!(f, "on") + } + UsageStatsLevel::Off => { + write!(f, "off") + } + UsageStatsLevel::NoStack => { + write!(f, "no-stack") + } + } + } +} + /// App Configuration #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -468,13 +495,11 @@ struct Args { #[clap(long, env)] lora_adapters: Option, - /// Disable sending of all usage statistics - #[clap(default_value = "false", long, env)] - disable_usage_stats: bool, - - /// Disable sending of crash reports, but allow anonymous usage statistics - #[clap(default_value = "false", long, env)] - disable_crash_reports: bool, + /// Control if anonymous usage stats are collected. + /// Options are "on", "off" and "no-stack" + /// Defaul is on. + #[clap(default_value = "on", long, env)] + usage_stats: UsageStatsLevel, } #[derive(Debug)] @@ -1222,12 +1247,8 @@ fn spawn_webserver( ]; // Pass usage stats flags to router - if args.disable_usage_stats { - router_args.push("--disable-usage-stats".to_string()); - } - if args.disable_crash_reports { - router_args.push("--disable-crash-reports".to_string()); - } + router_args.push("--usage-stats".to_string()); + router_args.push(args.usage_stats.to_string()); // Grammar support if args.disable_grammar_support { diff --git a/router/src/server.rs b/router/src/server.rs index ccbd1535..dcbaa2ad 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -7,14 +7,13 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; -use crate::usage_stats; use crate::validation::ValidationError; use crate::{ - BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, - GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, - HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, - SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse, - ToolCallDelta, ToolCallMessage, Url, Usage, Validation, + usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, + GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, + HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, + OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, + TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1505,8 +1504,7 @@ pub async fn run( messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, - disable_usage_stats: bool, - disable_crash_reports: bool, + usage_stats_level: usage_stats::UsageStatsLevel, ) -> Result<(), WebServerError> { // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue @@ -1698,33 +1696,32 @@ pub async fn run( // Only send usage stats when TGI is run in container and the function returns Some let is_container = matches!(usage_stats::is_container(), Ok(true)); - - let user_agent = if !disable_usage_stats && is_container { - let reduced_args = usage_stats::Args::new( - config.clone(), - tokenizer_config.tokenizer_class.clone(), - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - // waiting_served_ratio, - // max_batch_prefill_tokens, - // max_batch_total_tokens, - // max_waiting_tokens, - // max_batch_size, - revision.clone(), - validation_workers, - messages_api_enabled, - disable_grammar_support, - max_client_batch_size, - disable_usage_stats, - disable_crash_reports, - ); - Some(usage_stats::UserAgent::new(reduced_args)) - } else { - None + let user_agent = match (usage_stats_level, is_container) { + (usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_config.tokenizer_class.clone(), + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, + revision.clone(), + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats_level, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } + _ => None, }; if let Some(ref ua) = user_agent { @@ -1780,21 +1777,18 @@ pub async fn run( Ok(()) } Err(e) => { - if !disable_crash_reports { - let error_event = usage_stats::UsageStatsEvent::new( - ua.clone(), - usage_stats::EventType::Error, - Some(e.to_string()), - ); - error_event.send().await; - } else { - let unknow_error_event = usage_stats::UsageStatsEvent::new( - ua.clone(), - usage_stats::EventType::Error, - Some("unknow_error".to_string()), - ); - unknow_error_event.send().await; - } + let description = match usage_stats_level { + usage_stats::UsageStatsLevel::On => Some(e.to_string()), + usage_stats::UsageStatsLevel::NoStack => Some("unknow_error".to_string()), + _ => None, + }; + let event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + description, + ); + event.send().await; + Err(e) } } diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index fa9f3637..0282ac63 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -1,4 +1,5 @@ use crate::config::Config; +use clap::ValueEnum; use csv::ReaderBuilder; use reqwest::header::HeaderMap; use serde::Serialize; @@ -13,6 +14,13 @@ use uuid::Uuid; const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; +#[derive(Copy, Clone, Debug, Serialize, ValueEnum)] +pub enum UsageStatsLevel { + On, + NoStack, + Off, +} + #[derive(Debug, Clone, Serialize)] pub struct UserAgent { pub uid: String, @@ -71,7 +79,7 @@ impl UsageStatsEvent { #[derive(Debug, Clone, Serialize)] pub struct Args { model_config: Option, - tokenizer_config: Option, + tokenizer_class: Option, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, @@ -88,15 +96,14 @@ pub struct Args { messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, - disable_usage_stats: bool, - disable_crash_reports: bool, + usage_stats_level: UsageStatsLevel, } impl Args { #[allow(clippy::too_many_arguments)] pub fn new( model_config: Option, - tokenizer_config: Option, + tokenizer_class: Option, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, @@ -113,12 +120,11 @@ impl Args { messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, - disable_usage_stats: bool, - disable_crash_reports: bool, + usage_stats_level: UsageStatsLevel, ) -> Self { Self { model_config, - tokenizer_config, + tokenizer_class, max_concurrent_requests, max_best_of, max_stop_sequences, @@ -135,8 +141,7 @@ impl Args { messages_api_enabled, disable_grammar_support, max_client_batch_size, - disable_usage_stats, - disable_crash_reports, + usage_stats_level, } } } From d70da59c25419072c2723163ae547459008e6083 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Thu, 1 Aug 2024 17:08:36 +0800 Subject: [PATCH 181/340] enable HuggingFaceM4/idefics-9b in intel gpu (#2338) Signed-off-by: Wang, Yi A --- .../models/custom_modeling/idefics_modeling.py | 14 +++++++++++++- server/text_generation_server/models/idefics.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 771bc234..fc6becc4 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -351,7 +351,19 @@ class IdeficsRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: + if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + residual is not None, + ) + return out + elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index af224254..29929b98 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -20,6 +20,8 @@ from text_generation_server.utils import ( ) from text_generation_server.utils.quantization import get_loader +from text_generation_server.utils.import_utils import SYSTEM + class IDEFICSSharded(IdeficsCausalLM): def __init__( @@ -37,6 +39,14 @@ class IDEFICSSharded(IdeficsCausalLM): # 9b seems to work correctly enough in float16, but 80b seems # to be really saturating for f16. dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype From ccddb30c029ecaf82e01ef3523b2bcba612fbe9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 1 Aug 2024 15:38:57 +0200 Subject: [PATCH 182/340] Fix cache block size for flash decoding (#2351) * Fix cache block size for flash decoding This seems to have been accidentally dropped during the TRT-LLM PR rebase. * Also run CI on changes to `backends` --- .github/workflows/ci_build.yaml | 1 + backends/v3/src/backend.rs | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index d62297e4..5ca2854a 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -10,6 +10,7 @@ on: paths: - ".github/workflows/build.yaml" - "integration-tests/**" + - "backends/**" - "server/**" - "proto/**" - "router/**" diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 49e2bc8f..d82355de 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,9 +35,16 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; + let queue = Queue::new( requires_padding, - 16, + block_size, window_size, speculate, max_batch_total_tokens, From 48fec7b1981993105249e6281ce6e7e8a02c3853 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 1 Aug 2024 17:03:28 +0200 Subject: [PATCH 183/340] Unify attention output handling (#2343) - Always return the hidden states. - Create the output tensor inside the `attention` and `paged_attention` functions. This removes the difference between how the output is handled between attention (output parameter) and paged attention (return value). This also removes the assumption that the attention implementation can write to an output tensor (in preparation of FlashInfer). --- .../layers/attention/cuda.py | 17 ++++++++++------- .../layers/attention/ipex.py | 5 +++-- .../layers/attention/rocm.py | 9 ++++++--- .../custom_modeling/flash_cohere_modeling.py | 7 +------ .../custom_modeling/flash_dbrx_modeling.py | 7 +------ .../flash_deepseek_v2_modeling.py | 9 ++------- .../custom_modeling/flash_gemma2_modeling.py | 7 +------ .../custom_modeling/flash_gemma_modeling.py | 7 +------ .../custom_modeling/flash_gpt2_modeling.py | 7 +------ .../custom_modeling/flash_llama_modeling.py | 7 +------ .../custom_modeling/flash_mistral_modeling.py | 7 +------ .../custom_modeling/flash_mixtral_modeling.py | 7 +------ .../custom_modeling/flash_neox_modeling.py | 7 +------ .../custom_modeling/flash_phi_modeling.py | 7 +------ .../custom_modeling/flash_qwen2_modeling.py | 7 +------ .../models/custom_modeling/flash_rw_modeling.py | 14 ++------------ .../flash_santacoder_modeling.py | 7 +------ .../flash_starcoder2_modeling.py | 7 +------ 18 files changed, 36 insertions(+), 109 deletions(-) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index c0c4da4d..dff742dc 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -34,7 +34,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -85,7 +84,7 @@ def paged_attention( # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. if softcap is None: softcap = 0.0 - out2 = flash_attn_2_cuda.varlen_fwd( + out = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, @@ -108,13 +107,15 @@ def paged_attention( False, # return softmax None, # generator ) - return out2[0] + return out[0] else: if softcap is not None: raise RuntimeError("Paged attention doesn't support softcapping") input_lengths = seqlen.input_lengths from vllm._C import ops + out = torch.empty_like(query) + use_v1 = max_s <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512 ) @@ -200,13 +201,13 @@ except ImportError: SUPPORTS_WINDOWING = V2 + if V2: def attention( q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -214,6 +215,7 @@ if V2: causal=True, softcap=0.0, ): + out = torch.empty_like(q) if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( @@ -238,7 +240,7 @@ if V2: softcap, False, None, - ) + )[0] else: @@ -246,7 +248,6 @@ else: q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -286,6 +287,8 @@ else: .reshape(original_shape[0], -1, original_shape[2]) ) + out = torch.empty_like(q) + return flash_attn_cuda.fwd( q, k, @@ -302,4 +305,4 @@ else: False, 0, None, - ) + )[0] diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 45a0a03e..e0956b26 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -10,13 +10,14 @@ def attention( q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, ): + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( q, @@ -49,7 +50,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -59,6 +59,7 @@ def paged_attention( seqlen: Seqlen, max_s: int, ): + out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 3ebe492a..69e64162 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -39,7 +39,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -72,6 +71,8 @@ def paged_attention( max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE input_lengths = input_lengths.input_lengths + out = torch.empty_like(query) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -174,7 +175,6 @@ if ENGINE == "ck": q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -184,6 +184,8 @@ if ENGINE == "ck": if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( q, @@ -209,13 +211,14 @@ elif ENGINE == "triton": q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, ): + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( q, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index b73dada6..e02a31d9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 3c301c8d..d3d1d1ef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 3e84b4a8..0905d3c2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # Output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, ) # Decode else: - paged_attention( - attn_output, + attn_output = paged_attention( query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 0dc5b9cf..de86f514 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index dfe6510c..178efadb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a55a4af3..a19cff8c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index b55ddc23..9ea19a87 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 3d8f6bf4..dda53ff3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 7cdca553..85431c6c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 623b164c..b1b03ad7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module): reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(qkv[:, 0]) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, qkv[:, 0], kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a1ce03b9..a9e18348 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module): # Reshape key and value and cache reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index e357a287..865cc85d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d7cad480..708641e7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], @@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module): slots, ) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2b939a10..c2676782 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module): key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index cfa891d4..e562eb89 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], From 688321bcc46ad4542ac619cabe710333d53d64a4 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 Aug 2024 09:11:40 -0400 Subject: [PATCH 184/340] fix: attempt forward on flash attn2 to check hardware support (#2335) * fix: attempt forward on flash attn2 to check hardware support * fix: warn window_size_left when using flash attn 1 * fix: prefer version check over test op and avoid window_size_left if not flash attn2 * fix: improve condtional and error message * fix: update sliding window conditional * fix: simplify changes and revert model changes * fix: avoid changing conditional * fix: typo tweak --- server/text_generation_server/layers/attention/cuda.py | 4 ++++ server/text_generation_server/models/__init__.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index dff742dc..2b898831 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -172,6 +172,10 @@ def paged_attention( try: + is_ampere_or_newer = major >= 8 and minor >= 0 + if not is_ampere_or_newer: + raise ImportError("FlashAttention only supports Ampere GPUs or newer.") + import flash_attn_2_cuda V2 = True diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3dc24159..ae791ef8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -484,6 +484,9 @@ def get_model( ) sliding_window = config_dict.get("sliding_window", -1) + if max_input_tokens is not None and max_input_tokens <= sliding_window: + sliding_window = -1 + if ( (sliding_window is not None and sliding_window != -1) and not SUPPORTS_WINDOWING From 8b0f5feb020096324f413e8abe7a2841b4979a3e Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 Aug 2024 12:36:44 -0400 Subject: [PATCH 185/340] feat: include local lora adapter loading docs (#2359) --- docs/source/conceptual/lora.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/conceptual/lora.md b/docs/source/conceptual/lora.md index 08df767c..cfc2109b 100644 --- a/docs/source/conceptual/lora.md +++ b/docs/source/conceptual/lora.md @@ -36,6 +36,18 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia ``` +additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example: + +```bash +LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter +``` + +note it's possible to mix adapter_ids with adapter_id=adapter_path e.g. + +```bash +LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/ +``` + In the server logs, you will see the following message: ```txt From 83d1f23fea9bf8fdb73662549ecb3d24a97abde1 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 07:49:53 -0400 Subject: [PATCH 186/340] fix: return the out tensor rather then the functions return value (#2361) --- server/text_generation_server/layers/attention/cuda.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 2b898831..96b654d0 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -292,8 +292,7 @@ else: ) out = torch.empty_like(q) - - return flash_attn_cuda.fwd( + flash_attn_cuda.fwd( q, k, v, @@ -309,4 +308,5 @@ else: False, 0, None, - )[0] + ) + return out From 88e07f12cc3b6e138becf3abee9265c231ca1599 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 07:51:32 -0400 Subject: [PATCH 187/340] feat: implement a templated endpoint for visibility into chat requests (#2333) * feat: implement a templated endpoint for visibility into chat requests * feat: improve to tokenize too * fix: adjust return type * feat: simplify prepare_chat_input logic and adjust start stop chars --- router/src/lib.rs | 6 ++ router/src/server.rs | 210 ++++++++++++++++++++++++++++++------------- 2 files changed, 156 insertions(+), 60 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 14bb8270..386b0556 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse { pub details: Option
, } +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatTokenizeResponse { + pub(crate) tokenize_response: TokenizeResponse, + pub(crate) templated_text: String, +} + #[derive(Serialize, ToSchema)] #[serde(transparent)] pub(crate) struct TokenizeResponse(Vec); diff --git a/router/src/server.rs b/router/src/server.rs index dcbaa2ad..7655182a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,6 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; +use crate::ChatTokenizeResponse; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -22,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -115,6 +116,107 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/chat_tokenize", + request_body = ChatRequest, + responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse)) +)] +async fn get_chat_tokenize( + Extension(infer): Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + metrics::counter!("tgi_request_count").increment(1); + + let ChatRequest { + model, + max_tokens, + messages, + seed, + stop, + stream, + tools, + tool_choice, + tool_prompt, + temperature, + response_format, + .. + } = req; + + let tool_prompt = tool_prompt.unwrap_or_default(); + let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; + + let generate_request = GenerateRequest { + inputs, + parameters: GenerateParameters { + best_of: None, + temperature, + repetition_penalty: None, + frequency_penalty: None, + top_k: None, + top_p: None, + typical_p: None, + do_sample: true, + max_new_tokens: max_tokens, + return_full_text: None, + stop: stop.unwrap_or_default(), + truncate: None, + watermark: false, + details: false, + decoder_input_details: !stream, + seed, + top_n_tokens: None, + grammar: _grammar, + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), + }, + }; + + let input = generate_request.inputs.clone(); + let encoding = infer.tokenize(generate_request).await?; + if let Some(encoding) = encoding { + let tokens: Vec = encoding + .get_ids() + .iter() + .zip(encoding.get_offsets()) + .map(|(&id, &(start, stop))| { + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect(); + + let resp = ChatTokenizeResponse { + tokenize_response: TokenizeResponse(tokens), + templated_text: input, + }; + Ok((HeaderMap::new(), Json(resp))) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } +} + #[utoipa::path( get, tag = "Text Generation Inference", @@ -1034,63 +1136,14 @@ async fn chat_completions( Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - - // response_format and tools are mutually exclusive - if response_format.is_some() && tools.as_ref().is_some() { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Grammar and tools are mutually exclusive".to_string(), - error_type: "grammar and tools".to_string(), - }), - )); - } - - // extract tool grammar if present - let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { - Ok(grammar) => grammar, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; - - // determine the appropriate arguments for apply_chat_template - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); - - let (tools_grammar_prompt, grammar) = match response_format { - Some(response_format) => (None, Some(response_format)), - None => ( - tools_grammar_prompt.clone(), - tools_grammar_prompt.map(|(grammar, _)| grammar.clone()), - ), - }; - - // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { - Ok(inputs) => inputs, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; + let (inputs, grammar, tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; // build the request passing some parameters let generate_request = GenerateRequest { @@ -1360,8 +1413,11 @@ async fn tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = - String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); SimpleToken { id, text, @@ -2036,6 +2092,7 @@ async fn start( } let info_routes = Router::new() .route("/", get(health)) + .route("/chat_tokenize", post(get_chat_tokenize)) .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) @@ -2332,3 +2389,36 @@ fn create_post_processor( Ok(post_processor) } + +type PreparedInput = (String, Option, Option); + +fn prepare_chat_input( + infer: &Infer, + response_format: Option, + tools: Option>, + tool_choice: ToolChoice, + tool_prompt: &str, + messages: Vec, +) -> Result { + if response_format.is_some() && tools.is_some() { + return Err(InferError::ToolError( + "Grammar and tools are mutually exclusive".into(), + )); + } + + if let Some(format) = response_format { + let inputs = infer.apply_chat_template(messages, None)?; + return Ok((inputs, Some(format), None)); + } + + // if tools are set, apply the tool grammar and then the chat template + let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; + let grammar = tool_grammar + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + let tools_grammar_prompt = tool_grammar + .as_ref() + .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); + let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + Ok((inputs, grammar, tool_grammar)) +} From b4562e1369a3506d702d6ca539b4871daa107a4b Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 13:09:50 -0400 Subject: [PATCH 188/340] feat: prefer stop over eos_token to align with openai finish_reason (#2344) --- router/src/lib.rs | 11 ++++++++++- router/src/server.rs | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 386b0556..a956b058 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -619,7 +619,7 @@ impl ChatCompletion { message, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), - finish_reason: details.finish_reason.to_string(), + finish_reason: details.finish_reason.format(true), }], usage: Usage { prompt_tokens: details.prefill.len() as u32, @@ -1117,6 +1117,15 @@ impl std::fmt::Display for FinishReason { } } +impl FinishReason { + pub fn format(&self, use_stop: bool) -> String { + match self { + FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(), + _ => self.to_string(), + } + } +} + #[derive(Serialize, ToSchema)] pub(crate) struct BestOfSequence { #[schema(example = "test")] diff --git a/router/src/server.rs b/router/src/server.rs index 7655182a..4b6fe50c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1021,7 +1021,7 @@ async fn completions( total_tokens += details.prefill.len() as u32 + details.generated_tokens; Ok(CompletionComplete { - finish_reason: details.finish_reason.to_string(), + finish_reason: details.finish_reason.format(true), index: index as u32, logprobs: None, text: generation.generated_text, @@ -1212,7 +1212,7 @@ async fn chat_completions( tool_calls, current_time, logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + stream_token.details.map(|d| d.finish_reason.format(true)), ), )) .unwrap_or_else(|e| { From 5400c7155dabc78579c0219093dc8e2b98f7b254 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 13:10:19 -0400 Subject: [PATCH 189/340] feat: return the generated text when parsing fails (#2353) --- router/src/server.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 4b6fe50c..1d1cd36a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1246,9 +1246,13 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - let gen_text_value: Value = serde_json::from_str(&generation.generated_text) - .map_err(|e| InferError::ToolError(e.to_string()))?; - + let gen_text_value: Value = + serde_json::from_str(&generation.generated_text).map_err(|e| { + InferError::ToolError(format!( + "Failed to parse generated text: {} {:?}", + e, generation.generated_text + )) + })?; let function = gen_text_value.get("function").ok_or(InferError::ToolError( "No function found in generated text".to_string(), ))?; From db873be1772e0bde079573937fe4a95e56d6764f Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 13:33:22 -0400 Subject: [PATCH 190/340] fix: default num_ln_in_parallel_attn to one if not supplied (#2364) --- .../models/custom_modeling/flash_rw_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 708641e7..0691da9b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module): class FlashRWLayerNorm(nn.Module): def __init__(self, config, prefix: str, weights): super().__init__() - self.num_ln = config.num_ln_in_parallel_attn + # Falcon2 includes the number of layer norms in the config + # in the case no number of layer norms is provided, we default to 1 + self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1) if self.num_ln == 1: self.input_ln = FastLayerNorm.load( From 3ccde430d95157a10a6272950e3b6a6031db6b53 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 15:25:30 -0400 Subject: [PATCH 191/340] fix: prefer original layernorm names for 180B (#2365) --- .../models/custom_modeling/flash_rw_modeling.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 0691da9b..fc002082 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module): prefix = f"{prefix}.h.{layer_id}" + # NOTE: Falcon 180B uses the ln_attn prefix + ln_prefix = "input_layernorm" + if config.num_hidden_layers == 80: + ln_prefix = "ln_attn" + self.input_layernorm = FastLayerNorm.load( - prefix=f"{prefix}.input_layernorm", + prefix=f"{prefix}.{ln_prefix}", weights=weights, eps=config.layer_norm_epsilon, ) @@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module): # in the case no number of layer norms is provided, we default to 1 self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1) + # Falcon 180B uses the ln_attn prefix and has 2 layer norms + if config.num_hidden_layers == 80: + self.num_ln = 2 + if self.num_ln == 1: self.input_ln = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", From 11fab8a20c4439c8ee0a40f21cc7a1465596fb82 Mon Sep 17 00:00:00 2001 From: almersawi <43927639+almersawi@users.noreply.github.com> Date: Thu, 8 Aug 2024 03:45:23 +0400 Subject: [PATCH 192/340] fix: fix num_ln_in_parallel_attn attribute name typo in RWConfig (#2350) Co-authored-by: Islam Almersawi --- .../models/custom_modeling/flash_rw_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fc002082..10f995a3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -94,7 +94,7 @@ class RWConfig(PretrainedConfig): else kwargs.pop("n_head", 8) ) self.layer_norm_epsilon = layer_norm_epsilon - self.num_ln_in_parallel_attention = num_ln_in_prallel_attention + self.num_ln_in_parallel_attn = num_ln_in_prallel_attention self.initializer_range = initializer_range self.use_cache = use_cache self.hidden_dropout = hidden_dropout From 3ea8e8a2d5d9bc0683e2bb3f65e0e883f74c5592 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 7 Aug 2024 21:32:37 -0400 Subject: [PATCH 193/340] add gptj modeling in TGI #2366 (CI RUN) (#2372) * add gptj modeling Signed-off-by: Wang, Yi A * fix: update docs for model addition * fix: adjust syntax typo * fix: adjust syntax typo again --------- Signed-off-by: Wang, Yi A Co-authored-by: Wang, Yi A --- docs/source/supported_models.md | 1 + router/src/config.rs | 1 + .../text_generation_server/models/__init__.py | 43 ++ .../custom_modeling/flash_gptj_modeling.py | 405 ++++++++++++++++++ 4 files changed, 450 insertions(+) create mode 100644 server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index bc124f31..b78104df 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -32,6 +32,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct) - [Gpt2](https://huggingface.co/openai-community/gpt2) - [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) +- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b) - [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal) diff --git a/router/src/config.rs b/router/src/config.rs index 7737165e..5d0be9c8 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -153,6 +153,7 @@ pub enum Config { Bloom, Mpt, Gpt2, + Gptj, GptNeox, Phi, #[serde(rename = "phi-msft")] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ae791ef8..1f9c7526 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -132,6 +132,9 @@ try: from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( FlashGPT2ForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_gptj_modeling import ( + FlashGPTJForCausalLM, + ) from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) @@ -294,6 +297,11 @@ class ModelType(enum.Enum): "name": "Gpt Neox", "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", } + GPTJ = { + "type": "gptj", + "name": "Gptj", + "url": "https://huggingface.co/EleutherAI/gpt-j-6b", + } IDEFICS = { "type": "idefics", "name": "Idefics", @@ -641,6 +649,41 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == GPTJ: + if FLASH_ATTENTION: + try: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTJForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + except RuntimeError as e: + # Lots of legacy models with various weight names. + log_master(logger.warning, f"Couldn't load flash gptj variant: {e}") + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J")) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py new file mode 100644 index 00000000..eb667384 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -0,0 +1,405 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.utils.import_utils import SYSTEM + + +def load_attention(config, prefix: str, weights): + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias) + return TensorParallelRowLinear(linear, process_group=weights.process_group) + + +class GPTJRotary(PositionRotaryEmbedding): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + import rotary_emb + + q1 = query[..., ::2] + q2 = query[..., 1::2] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., ::2] + k2 = key[..., 1::2] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + from vllm._C import ops + + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) + + +class FlashGPTJAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + self.head_size = self.hidden_size // self.num_heads + self.softmax_scale = self.head_size**-0.5 + self.rotary_dim = config.rotary_dim + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.query_key_value = load_attention( + config, + prefix=prefix, + weights=weights, + ) + + self.o_proj = load_row( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=False, + ) + + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + self.rotary_emb = GPTJRotary.static( + config=config, + dim=self.rotary_dim, + base=10000, + device=weights.device, + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + query, key, value = self.query_key_value(hidden_states).split( + self.head_size * self.num_heads, dim=1 + ) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + + # Compute rotary embeddings on rotary_ndims + if self.rotary_dim is not None: + self.rotary_emb( + query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin + ) + else: + self.rotary_emb(query, key, cos, sin) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + key, + value, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GPTJMLP(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + act = config.activation_function + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.fc_in = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.fc_in", weights=weights, bias=True + ) + + self.fc_out = load_row( + config, + prefix=f"{prefix}.fc_out", + weights=weights, + bias=True, + ) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + return self.fc_out(hidden_states) + + +class FlashGPTJLayer(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.self_attn = FlashGPTJAttention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + hidden_states, residual = self.input_layernorm(hidden_states, residual) + # Self Attention + attn_output = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + feed_forward_hidden_states = self.mlp(hidden_states) + + return attn_output + feed_forward_hidden_states, residual + + +class FlashGPTJModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + + self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights) + self.layers = nn.ModuleList( + [ + FlashGPTJLayer( + prefix=( + f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}" + ), + config=config, + weights=weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.ln_f = FastLayerNorm.load( + prefix="ln_f" if not prefix else f"{prefix}.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def forward( + self, + input_ids: Optional[torch.LongTensor], + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states + + +class FlashGPTJForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + self.model = FlashGPTJModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits From 9b1b545bb47a596aff6c65abadd0c1d626578192 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 7 Aug 2024 23:14:02 -0400 Subject: [PATCH 194/340] Fix the prefix for OPT model in opt_modelling.py #2370 (CI RUN) (#2371) * Fix the bug * fix: run lints * fix: small syntax tweak --------- Co-authored-by: Sadra Barikbin --- integration-tests/models/test_opt.py | 19 ++++++++++++++ .../models/custom_modeling/opt_modeling.py | 25 ++++++++++--------- 2 files changed, 32 insertions(+), 12 deletions(-) create mode 100644 integration-tests/models/test_opt.py diff --git a/integration-tests/models/test_opt.py b/integration-tests/models/test_opt.py new file mode 100644 index 00000000..cbeb6376 --- /dev/null +++ b/integration-tests/models/test_opt.py @@ -0,0 +1,19 @@ +import pytest + + +@pytest.fixture(scope="module") +def opt_sharded_handle(launcher): + with launcher("facebook/opt-6.7b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def opt_sharded(opt_sharded_handle): + await opt_sharded_handle.health(300) + return opt_sharded_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +async def test_opt(opt_sharded): + pass diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 84a1c069..bd440321 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module): super().__init__() self.offset = 2 self.weight = nn.Parameter( - weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") + weights.get_tensor( + f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" + ) ) def forward( @@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"{prefix}.decoder.layers.{layer_id}" + prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel): self.max_target_positions = config.max_position_embeddings self.vocab_size = config.vocab_size + prefix = prefix + "." if prefix else "" + self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.decoder.embed_tokens", weights=weights + prefix=f"{prefix}decoder.embed_tokens", weights=weights ) self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) if config.word_embed_proj_dim != config.hidden_size: self.project_out = FastLinear.load( config, - prefix=f"{prefix}.decoder.project_out", + prefix=f"{prefix}decoder.project_out", weights=weights, bias=False, ) @@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel): if config.word_embed_proj_dim != config.hidden_size: self.project_in = FastLinear.load( config, - prefix=f"{prefix}.decoder.project_in", + prefix=f"{prefix}decoder.project_in", weights=weights, bias=False, ) @@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel): # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS + prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS ) else: self.final_layer_norm = None @@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel): def __init__(self, prefix, config, weights): super().__init__(config) - if not prefix: - prefix = "model" - else: - prefix = f"{prefix}.model" - self.model = OPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights + config, + prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens", + weights=weights, ) def forward( From 06b638f3108d4c28c1a31539f76f9f6f3dcb33fb Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Aug 2024 11:14:06 -0400 Subject: [PATCH 195/340] Pr 2374 ci branch (#2378) * Update __init__.py Fix issue with NoneType comparison for max_input_tokens and sliding_window - Add default values for max_input_tokens and sliding_window to handle None cases. - Ensure the comparison between max_input_tokens and sliding_window is handled correctly to prevent TypeError. - This change addresses the error: TypeError: '<=' not supported between instances of 'int' and 'NoneType'. * Update __init__.py Handle NoneType in sliding_window comparison to fix TypeError in __init__.py by ensuring the comparison logic accounts for NoneType values, preventing errors and improving code robustness. * fix: syntax/style tweak --------- Co-authored-by: Praz --- server/text_generation_server/models/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1f9c7526..da14d083 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -490,7 +490,12 @@ def get_model( raise RuntimeError( "Sharding is currently not supported with `exl2` quantization" ) - sliding_window = config_dict.get("sliding_window", -1) + + sliding_window = ( + config_dict.get("sliding_window") + if config_dict.get("sliding_window") is not None + else -1 + ) if max_input_tokens is not None and max_input_tokens <= sliding_window: sliding_window = -1 From 3893d00927626294ff45dd1500f791c8bf575a81 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 9 Aug 2024 00:08:52 +0800 Subject: [PATCH 196/340] fix EleutherAI/gpt-neox-20b does not work in tgi (#2346) Signed-off-by: Wang, Yi A --- .../models/custom_modeling/flash_neox_modeling.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b1b03ad7..67237d5c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + # Compute rotary embeddings on rotary_ndims + query_rot = qkv[:, 0][..., : self.rotary_dim] + query_pass = qkv[:, 0][..., self.rotary_dim :] + key_rot = qkv[:, 1][..., : self.rotary_dim] + key_pass = qkv[:, 1][..., self.rotary_dim :] + # Inplace rotary - self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin) + self.rotary_emb(query_rot, key_rot, cos, sin) + qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) + qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) From 1057f281282d957e357f48baf907662acfd1da56 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Aug 2024 12:30:29 -0400 Subject: [PATCH 197/340] Pr 2337 ci branch (#2379) * hotfix: fix xpu crash brought by code refine. torch.xpu rely on import ipex Signed-off-by: Wang, Yi A * reable gemma2 in xpu Signed-off-by: Wang, Yi A * fix in regression in ipex flashattention Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A Co-authored-by: Wang, Yi A --- server/text_generation_server/layers/attention/ipex.py | 7 ++++++- server/text_generation_server/utils/import_utils.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index e0956b26..d7cf780a 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -2,6 +2,7 @@ import intel_extension_for_pytorch as ipex import torch from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen +from typing import Optional SUPPORTS_WINDOWING = False @@ -15,11 +16,12 @@ def attention( softmax_scale, window_size_left=-1, causal=True, + softcap: Optional[float] = None, ): out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - return ipex.llm.functional.varlen_attention( + ipex.llm.functional.varlen_attention( q, k, v, @@ -36,6 +38,8 @@ def attention( None, ) + return out + def reshape_and_cache( key: torch.Tensor, @@ -58,6 +62,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + softcap: Optional[float] = None, ): out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 7c053014..782b4f15 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -56,6 +56,8 @@ elif torch.version.cuda is not None and torch.cuda.is_available(): get_free_memory = get_cuda_free_memory elif is_ipex_available(): SYSTEM = "ipex" + import intel_extension_for_pytorch # noqa: F401 + if hasattr(torch, "xpu") and torch.xpu.is_available(): empty_cache = torch.xpu.empty_cache synchronize = torch.xpu.synchronize From 853fb96fec6979b3c3f30f8ed43b8ae1c7651570 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Aug 2024 14:08:56 -0400 Subject: [PATCH 198/340] fix: prefer hidden_activation over hidden_act in gemma2 (#2381) --- .../models/custom_modeling/flash_gemma2_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index de86f514..54d212e6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -265,7 +265,7 @@ class FlashGemma2Attention(torch.nn.Module): class Gemma2MLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - act = config.hidden_act + act = config.hidden_activation self.act = ( ACT2FN[act] if "gelu" not in act From b1bc0ecb7fd5805a0fd340bae24b76189a0462b2 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 8 Aug 2024 22:06:57 +0200 Subject: [PATCH 199/340] Update Quantization docs and minor doc fix. (#2368) * Update Quantization docs and minor doc fix. * update readme with latest quants info * Apply suggestions from code review Co-authored-by: Pedro Cuenca * up --------- Co-authored-by: Pedro Cuenca --- .redocly.lint-ignore.yaml | 79 +++++++++++++++++++ docs/openapi.json | 2 +- .../source/basic_tutorials/preparing_model.md | 2 +- .../basic_tutorials/visual_language_models.md | 2 +- docs/source/conceptual/quantization.md | 59 ++++++++------ 5 files changed, 118 insertions(+), 26 deletions(-) create mode 100644 .redocly.lint-ignore.yaml diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml new file mode 100644 index 00000000..382c9ab6 --- /dev/null +++ b/.redocly.lint-ignore.yaml @@ -0,0 +1,79 @@ +# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API. +# See https://redoc.ly/docs/cli/ for more information. +docs/openapi.json: + no-empty-servers: + - '#/openapi' + spec: + - >- + #/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum + - >- + #/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/grammar/nullable' + - >- + #/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum + - '#/components/schemas/GenerateResponse/properties/details/nullable' + - '#/components/schemas/StreamResponse/properties/details/nullable' + - '#/components/schemas/ChatRequest/properties/response_format/nullable' + - '#/components/schemas/ChatRequest/properties/tool_choice/nullable' + - '#/components/schemas/ToolChoice/nullable' + - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable' + - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable' + no-invalid-media-type-examples: + - '#/paths/~1/post/responses/422/content/application~1json/example' + - '#/paths/~1/post/responses/424/content/application~1json/example' + - '#/paths/~1/post/responses/429/content/application~1json/example' + - '#/paths/~1/post/responses/500/content/application~1json/example' + - '#/paths/~1generate/post/responses/422/content/application~1json/example' + - '#/paths/~1generate/post/responses/424/content/application~1json/example' + - '#/paths/~1generate/post/responses/429/content/application~1json/example' + - '#/paths/~1generate/post/responses/500/content/application~1json/example' + - >- + #/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example + - '#/paths/~1tokenize/post/responses/404/content/application~1json/example' + - >- + #/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/500/content/application~1json/example + operation-4xx-response: + - '#/paths/~1health/get/responses' + - '#/paths/~1info/get/responses' + - '#/paths/~1metrics/get/responses' + no-unused-components: + - '#/components/schemas/Completion' + security-defined: + - '#/paths/~1/post' + - '#/paths/~1generate/post' + - '#/paths/~1generate_stream/post' + - '#/paths/~1health/get' + - '#/paths/~1info/get' + - '#/paths/~1metrics/get' + - '#/paths/~1tokenize/post' + - '#/paths/~1v1~1chat~1completions/post' + - '#/paths/~1v1~1completions/post' diff --git a/docs/openapi.json b/docs/openapi.json index 4d3a2398..657d6d19 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2080,4 +2080,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} +} \ No newline at end of file diff --git a/docs/source/basic_tutorials/preparing_model.md b/docs/source/basic_tutorials/preparing_model.md index 71ca5598..456ade44 100644 --- a/docs/source/basic_tutorials/preparing_model.md +++ b/docs/source/basic_tutorials/preparing_model.md @@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects. ## Quantization -TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) +TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) ## RoPE Scaling diff --git a/docs/source/basic_tutorials/visual_language_models.md b/docs/source/basic_tutorials/visual_language_models.md index 3770db0b..f152a2f0 100644 --- a/docs/source/basic_tutorials/visual_language_models.md +++ b/docs/source/basic_tutorials/visual_language_models.md @@ -84,7 +84,7 @@ print(chat) ``` -or with OpenAi's library: +or with OpenAI's [client library](https://github.com/openai/openai-python): ```python from openai import OpenAI diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index 8f26fdba..7507687f 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -1,6 +1,40 @@ # Quantization -TGI offers GPTQ and bits-and-bytes quantization to quantize large language models. +TGI offers many quantization schemes to run LLMs effectively and fast based on your use-case. TGI supports GPTQ, AWQ, bits-and-bytes, EETQ, Marlin, EXL2 and fp8 quantization. + +To leverage GPTQ, AWQ, Marlin and EXL2 quants, you must provide pre-quantized weights. Whereas for bits-and-bytes, EETQ and fp8, weights are quantized by TGI on the fly. + +We recommend using the official quantization scripts for creating your quants: +1. [AWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quantize.py) +2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py) +3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md) + +For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. + +## Quantization with bitsandbytes + +bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. + +8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much. +In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes +``` + +4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. + +In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4 +``` + +You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). + +Use `eetq` or `fp8` for other quantization schemes. + +In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset. ## Quantization with GPTQ @@ -35,25 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num- You can learn more about the quantization options by running `text-generation-server quantize --help`. If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). -You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). - -## Quantization with bitsandbytes - -bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. - -8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much. -In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 - -```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes -``` - -4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. - -In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 - -```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4 -``` - -You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). +You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). \ No newline at end of file From 6f2a468a641c9e9554f94ae03925b11daf89906c Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Aug 2024 04:54:32 -0400 Subject: [PATCH 200/340] Pr 2352 ci branch (#2382) * Fix unsigned integer underflow Passing --max-batch-size to the launcher actually had no effect because after a few requests the max_size passed to State::next_batch would underflow becoming a largo positive number. In the scheduler, as soon as the cached batch size reached the max_batch_size the max_size passed to next_batch becomes 0. Since the only check in that funcion is ``` if Some(batch_requests.len()) == max_size { break; } ``` and it's called after the `batch_requests.len()` has become 1, it doesn't do anything to prevent more than 0 requests from being batched. Now we have cached batch in the server that is large than max_batch_size and `max_size - batch_size as usize` underflows. Signed-off-by: Max de Bayser * fix: update v3 scheduler and ensure max_batch_size > 0 --------- Signed-off-by: Max de Bayser Co-authored-by: Max de Bayser --- backends/v3/src/backend.rs | 3 ++- backends/v3/src/main.rs | 8 ++++++++ backends/v3/src/queue.rs | 7 +++++++ router/src/infer/v2/queue.rs | 7 +++++++ router/src/infer/v2/scheduler.rs | 4 ++-- 5 files changed, 26 insertions(+), 3 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index d82355de..6b3e0526 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -168,7 +168,8 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 21952e66..471ddb5a 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -150,6 +150,14 @@ async fn main() -> Result<(), RouterError> { } } + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + let (backend, _backend_info) = connect_backend( max_input_tokens, max_total_tokens, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 9427bd60..b457389c 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -226,6 +226,13 @@ impl State { } } + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 0b51645a..696cbfc8 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -205,6 +205,13 @@ impl State { } } + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 3d6c36cf..cc333674 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -161,8 +161,8 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); - + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) From 4a16da5d4942a42def6953af1bf18a92690fc140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 11:42:00 +0200 Subject: [PATCH 201/340] Add FlashInfer support (#2354) This change adds support for FlashInfer. FlashInfer can be enabled using `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`. Since this functionality is currently only for testing, FlashInfer is not installed anywhere yet. The FlashInfer API is quite different from FlashAttention/vLLM in that it requires more global bookkeeping: * A wrapper class needs to be contstructed (which we just call *state*). Since this is fairly expensive (due to pinned host memory allocation), we only do this once in a FlashCausalLM instance or for each CUDA Graph size. * Each model forward call needs to be wrapped in `begin_forward` and `end_forward`. This sets up data structures that can be reused for all calls to attention for that forward call. When calling attention, we need access to the state object. To avoid passing an argument down the call chain (which would require changes to all models), we use a context variable. Each model forward call is wrapped using a context manager that does all the bookkeeping for such a call: * Set the context variable to the forward call's state. * Call `begin_forward` on the state. * Yield. * Call `end_forward` on the state. * Reset the context variable. We cannot use a single shared global variable for this, since e.g. CUDA Graphs of different sizes each have their own state. --- .../layers/attention/common.py | 4 +- .../layers/attention/cuda.py | 46 ++++- .../layers/attention/flash_infer.py | 164 +++++++++++++++++ .../models/flash_causal_lm.py | 171 ++++++++++++++---- .../text_generation_server/models/globals.py | 5 + 5 files changed, 346 insertions(+), 44 deletions(-) create mode 100644 server/text_generation_server/layers/attention/flash_infer.py diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index bd0717ce..b986a082 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER import torch from typing import Optional -if FLASH_DECODING: +if FLASH_DECODING or FLASH_INFER: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 96b654d0..1b8e9209 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,6 +1,10 @@ import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.models.globals import ( + FLASH_DECODING, + BLOCK_SIZE, + FLASH_INFER, +) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -23,7 +27,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -72,7 +76,16 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_DECODING: + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import decode_state + + return decode_state.get().forward( + query.contiguous(), + paged_kv_cache=(key_cache, value_cache), + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + elif FLASH_DECODING: max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -206,7 +219,32 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if V2: +if FLASH_INFER: + + def attention( + q, + k, + v, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + from text_generation_server.layers.attention.flash_infer import prefill_state + + return prefill_state.get().forward( + q, + k, + v, + causal=causal, + window_left=window_size_left, + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + +elif V2: def attention( q, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flash_infer.py new file mode 100644 index 00000000..56b53b2c --- /dev/null +++ b/server/text_generation_server/layers/attention/flash_infer.py @@ -0,0 +1,164 @@ +from typing import Optional +from contextvars import ContextVar +from contextlib import contextmanager + +import flashinfer +import torch + +prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( + "prefill_state" +) + +decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( + "decode_state" +) + +workspace: Optional[torch.Tensor] = None + + +def get_workspace(device): + """Get shared flashinfer workspace.""" + global workspace + if workspace is None: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + return workspace + + +def create_prefill_state( + *, + device: torch.device, +): + """Create a prefill state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_state( + *, + state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, + cu_seqlens: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + token = prefill_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + kv_indptr=cu_seqlens, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_state.reset(token) + + +def create_decode_state( + *, + device: torch.device, + num_heads: int, + num_kv_heads: int, +): + """Create a decode state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=False, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +def create_decode_state_cuda_graphs( + *, + device: torch.device, + block_tables: torch.Tensor, + block_tables_ptr: torch.Tensor, + last_page_len: torch.Tensor, + num_heads: int, + num_kv_heads: int, +): + """ + Create a decode state for use with CUDA Graphs. `block_tables`, + `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are + therefore stored as part of the state. + """ + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=True, + paged_kv_indices_buffer=block_tables, + paged_kv_indptr_buffer=block_tables_ptr, + paged_kv_last_page_len_buffer=last_page_len, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +@contextmanager +def use_decode_state( + *, + state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, + input_lengths: torch.Tensor, + block_tables: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer decoding state to the given + `state` and parameters. This state will be used by all calls to the + `paged_attention` function while the context manager is active. + """ + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = decode_state.set(state) + + try: + state.begin_forward( + indptr=indptr, + indices=block_tables, + last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=page_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + decode_state.reset(token) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36bb2662..12aa7dcd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext import math import os import time @@ -15,7 +16,7 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Iterable, Optional, Tuple, List, Type, Dict +from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, FLASH_DECODING, + FLASH_INFER, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -907,6 +909,7 @@ class FlashCausalLM(Model): config.sliding_window = None self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -935,6 +938,21 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_prefill_state, + create_decode_state, + ) + + self.prefill_state = create_prefill_state(device=device) + + if not CUDA_GRAPHS: + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + super().__init__( model_id=model_id, model=model, @@ -972,7 +990,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: self.kv_cache = [ ( torch.empty( @@ -1044,38 +1062,66 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables.view(-1), + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + self.cuda_graphs[bs]["state"] = state + else: + state = None + torch.cuda.synchronize() # Run once outside to warmup - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths_, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -1295,23 +1341,28 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, + cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -1325,8 +1376,16 @@ class FlashCausalLM(Model): cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - # Replay the graph - cuda_graph["graph"].replay() + state = cuda_graph.get("state") + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + # Replay the graph + cuda_graph["graph"].replay() + # Slice output to the correct shape speculative_logits = ( cuda_graph["speculative_logits"][:bs] @@ -1698,3 +1757,39 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + def _forward_context( + self, + *, + block_tables: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + input_lengths: torch.Tensor, + state: Optional[Any] = None, + ) -> ContextManager: + if not FLASH_INFER: + return nullcontext() + + from text_generation_server.layers.attention.flash_infer import ( + use_decode_state, + use_prefill_state, + ) + + if cu_seqlen_prefill is not None: + return use_prefill_state( + state=state if state is not None else self.prefill_state, + cu_seqlens=cu_seqlen_prefill, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + ) + else: + assert input_lengths is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths, + block_tables=block_tables.view(-1), + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8d2431db..42b43c87 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,10 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master +FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} +if FLASH_INFER: + log_master(logger.info, "Using FLASH_INFER") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} @@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: log_master(logger.info, "Using FLASH_DECODING") + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: From dc0fa60f553203f31fc66dcf8222f337499b482a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 12:32:37 +0200 Subject: [PATCH 202/340] Add experimental flake (#2384) Add flake.nix --- flake.lock | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ flake.nix | 73 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000..bf06240f --- /dev/null +++ b/flake.lock @@ -0,0 +1,99 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1723099294, + "narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=", + "owner": "danieldk", + "repo": "nixpkgs", + "rev": "45892b6ec142eaf300d777926983a433b5842c88", + "type": "github" + }, + "original": { + "owner": "danieldk", + "ref": "cudnn-9.3", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ], + "tgi-nix": "tgi-nix" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "tgi-nix": { + "inputs": { + "flake-compat": "flake-compat", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1723188417, + "narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=", + "owner": "danieldk", + "repo": "tgi-nix", + "rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3", + "type": "github" + }, + "original": { + "owner": "danieldk", + "repo": "tgi-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000..c17950c1 --- /dev/null +++ b/flake.nix @@ -0,0 +1,73 @@ +{ + inputs = { + tgi-nix.url = "github:danieldk/tgi-nix"; + nixpkgs.follows = "tgi-nix/nixpkgs"; + flake-utils.url = "github:numtide/flake-utils"; + }; + outputs = + { + self, + nixpkgs, + flake-utils, + tgi-nix, + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + config = { + allowUnfree = true; + cudaSupport = true; + }; + pkgs = import nixpkgs { + inherit config system; + overlays = [ tgi-nix.overlay ]; + }; + in + { + devShells.default = + with pkgs; + mkShell { + buildInputs = + [ + cargo + openssl.dev + pkg-config + ] + ++ (with python3.pkgs; [ + venvShellHook + pip + + einops + fbgemm-gpu + flash-attn + flash-attn-layer-norm + flash-attn-rotary + grpc-interceptor + grpcio-reflection + grpcio-status + hf-transfer + loguru + marlin-kernels + opentelemetry-api + opentelemetry-exporter-otlp + opentelemetry-instrumentation-grpc + opentelemetry-semantic-conventions + peft + tokenizers + torch + transformers + vllm + ]); + + venvDir = "./.venv"; + + postVenv = '' + unset SOURCE_DATE_EPOCH + ''; + postShellHook = '' + unset SOURCE_DATE_EPOCH + ''; + }; + } + ); +} From afa14b7595f808df9a79eb4433740a736d074dea Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 9 Aug 2024 14:25:44 +0200 Subject: [PATCH 203/340] Using HF_HOME instead of CACHE to get token read in addition to models. (#2288) --- Dockerfile | 2 +- Dockerfile_amd | 2 +- Dockerfile_intel | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index e6f01ec2..ce33cb19 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ RUN cargo build --release FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 diff --git a/Dockerfile_amd b/Dockerfile_amd index 51231638..cdad0d28 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -167,7 +167,7 @@ RUN python setup.py build FROM base AS base-copy # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 diff --git a/Dockerfile_intel b/Dockerfile_intel index d20f0a01..158c5a89 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -57,7 +57,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 From e9ba044250d3e2df49d05e3d710cfd51ce1abd13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 14:56:20 +0200 Subject: [PATCH 204/340] flake: add fmt and clippy (#2389) --- flake.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flake.nix b/flake.nix index c17950c1..a67879fc 100644 --- a/flake.nix +++ b/flake.nix @@ -30,8 +30,10 @@ buildInputs = [ cargo + clippy openssl.dev pkg-config + rustfmt ] ++ (with python3.pkgs; [ venvShellHook From 1d4a35a23cc45de215878584242402a23c6ce42e Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 9 Aug 2024 15:01:34 +0200 Subject: [PATCH 205/340] Update documentation for Supported models (#2386) * Minor doc fixes * up. * Other minor updates. --- docs/source/conceptual/quantization.md | 4 ++-- docs/source/quicktour.md | 2 +- docs/source/supported_models.md | 8 ++++---- server/text_generation_server/models/__init__.py | 6 +++--- update_doc.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index 7507687f..a1ebe7e7 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -11,7 +11,7 @@ We recommend using the official quantization scripts for creating your quants: For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. -## Quantization with bitsandbytes +## Quantization with bitsandbytes, EETQ & fp8 bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. @@ -32,7 +32,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). -Use `eetq` or `fp8` for other quantization schemes. +Similarly you can use pass you can pass `--quantize eetq` or `--quantize fp8` for respective quantization schemes. In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 2313c69b..18e1a107 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -21,7 +21,7 @@ TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPU ## Consuming TGI -Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. +Once TGI is running, you can use the `generate` endpoint or the Open AI Chat Completion API compatible [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index b78104df..832f88ef 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -1,22 +1,22 @@ # Supported Models and Hardware -Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. +Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported. ## Supported Models - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) -- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) +- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) -- [Gemma2](https://huggingface.co/google/gemma2-9b) +- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) -- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) +- [Mistral](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407) - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) - [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) - [Phi](https://huggingface.co/microsoft/phi-1_5) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index da14d083..960b426b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -180,7 +180,7 @@ class ModelType(enum.Enum): LLAMA = { "type": "llama", "name": "Llama", - "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct", + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", } PHI3 = { "type": "phi3", @@ -200,7 +200,7 @@ class ModelType(enum.Enum): GEMMA2 = { "type": "gemma2", "name": "Gemma2", - "url": "https://huggingface.co/google/gemma2-9b", + "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", } COHERE = { "type": "cohere", @@ -220,7 +220,7 @@ class ModelType(enum.Enum): MISTRAL = { "type": "mistral", "name": "Mistral", - "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2", + "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407", } MIXTRAL = { "type": "mixtral", diff --git a/update_doc.py b/update_doc.py index 428d4452..e887e1c6 100644 --- a/update_doc.py +++ b/update_doc.py @@ -7,7 +7,7 @@ import os TEMPLATE = """ # Supported Models and Hardware -Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. +Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported. ## Supported Models From df719fd52746ccbe603fa01c862f74387af9164b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 15:24:21 +0200 Subject: [PATCH 206/340] flake: use rust-overlay (#2390) --- flake.lock | 22 ++++++++++++++++++++++ flake.nix | 19 +++++++++++++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/flake.lock b/flake.lock index bf06240f..043218e8 100644 --- a/flake.lock +++ b/flake.lock @@ -56,9 +56,31 @@ "tgi-nix", "nixpkgs" ], + "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1723170066, + "narHash": "sha256-SFkQfOA+8AIYJsPlQtxNP+z5jRLfz91z/aOrV94pPmw=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "fecfe4d7c96fea2982c7907997b387a6b52c1093", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, "systems": { "locked": { "lastModified": 1681028828, diff --git a/flake.nix b/flake.nix index a67879fc..fdd67d00 100644 --- a/flake.nix +++ b/flake.nix @@ -3,12 +3,17 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; }; outputs = { self, nixpkgs, flake-utils, + rust-overlay, tgi-nix, }: flake-utils.lib.eachDefaultSystem ( @@ -20,7 +25,10 @@ }; pkgs = import nixpkgs { inherit config system; - overlays = [ tgi-nix.overlay ]; + overlays = [ + rust-overlay.overlays.default + tgi-nix.overlay + ]; }; in { @@ -29,11 +37,14 @@ mkShell { buildInputs = [ - cargo - clippy openssl.dev pkg-config - rustfmt + (rust-bin.stable.latest.default.override { + extensions = [ + "rust-analyzer" + "rust-src" + ]; + }) ] ++ (with python3.pkgs; [ venvShellHook From 849bd93dc35f2efacbadbcdfb5639c77dbe9539a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 9 Aug 2024 16:41:17 +0200 Subject: [PATCH 207/340] Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385) * Using an enum for flash backens (paged/flashdecoding/flashinfer) * Early exit on server too. * Clippy. * Fix clippy and fmt. --- .gitignore | 1 + backends/v3/src/backend.rs | 16 ++++++---- docs/openapi.json | 2 +- docs/source/conceptual/quantization.md | 4 +-- launcher/src/main.rs | 2 +- router/src/infer/v2/scheduler.rs | 18 ++++++++---- router/src/lib.rs | 29 +++++++++++++++++++ .../layers/attention/common.py | 4 +-- .../layers/attention/cuda.py | 11 ++++--- .../layers/attention/rocm.py | 4 +-- .../models/flash_causal_lm.py | 11 ++++--- .../text_generation_server/models/globals.py | 14 ++++----- 12 files changed, 78 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 0de8b848..bd9d9125 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp data/ load_tests/*.json +server/fbgemmm diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 6b3e0526..68ddf00b 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{FinishReason, PrefillToken, Token}; +use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -35,12 +35,18 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 }; - let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, diff --git a/docs/openapi.json b/docs/openapi.json index 657d6d19..4d3a2398 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2080,4 +2080,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} \ No newline at end of file +} diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index a1ebe7e7..b7672a9f 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants: 2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py) 3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md) -For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. +For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. ## Quantization with bitsandbytes, EETQ & fp8 @@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num- You can learn more about the quantization options by running `text-generation-server quantize --help`. If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). -You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). \ No newline at end of file +You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). diff --git a/launcher/src/main.rs b/launcher/src/main.rs index dd4a92ec..4c8f288a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1475,7 +1475,7 @@ fn main() -> Result<(), LauncherError> { if config.model_type == Some("gemma2".to_string()) { tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("FLASH_DECODING", "1"); + std::env::set_var("ATTENTION", "flashdecoding"); } let config: Config = config.into(); diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index cc333674..0e5fc8a3 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,10 +1,10 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, + Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; +use crate::{Attention, FinishReason, PrefillToken, Token}; use nohash_hasher::IntMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -40,12 +40,18 @@ impl BackendV2 { generation_health: Arc, ) -> Self { // Infer shared state - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .expect(&format!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 }; - let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); diff --git a/router/src/lib.rs b/router/src/lib.rs index a956b058..66738706 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -15,6 +15,35 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[derive(PartialEq)] +pub enum Attention { + Paged, + FlashDecoding, + FlashInfer, +} + +#[derive(Debug)] +pub struct ParseError; + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cannot parse attention value") + } +} +impl std::error::Error for ParseError {} + +impl std::str::FromStr for Attention { + type Err = ParseError; + fn from_str(s: &str) -> Result { + match s { + "paged" => Ok(Attention::Paged), + "flashdecoding" => Ok(Attention::FlashDecoding), + "flashinfer" => Ok(Attention::FlashInfer), + _ => Err(ParseError), + } + } +} + #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")] diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index b986a082..f162230c 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER +from text_generation_server.models.globals import ATTENTION import torch from typing import Optional -if FLASH_DECODING or FLASH_INFER: +if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 1b8e9209..d039e1e7 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,9 +1,8 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( - FLASH_DECODING, + ATTENTION, BLOCK_SIZE, - FLASH_INFER, ) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -27,7 +26,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -76,7 +75,7 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import decode_state return decode_state.get().forward( @@ -85,7 +84,7 @@ def paged_attention( logits_soft_cap=softcap, sm_scale=softmax_scale, ) - elif FLASH_DECODING: + elif ATTENTION == "flashdecoding": max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -219,7 +218,7 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if FLASH_INFER: +if ATTENTION == "flashinfer": def attention( q, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 69e64162..16ce8d2b 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,7 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import ATTENTION from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master from loguru import logger @@ -28,7 +28,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if ATTENTION == "flashdecoding": shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 12aa7dcd..21b66a68 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -40,8 +40,7 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, - FLASH_DECODING, - FLASH_INFER, + ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -938,7 +937,7 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_prefill_state, create_decode_state, @@ -990,7 +989,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: self.kv_cache = [ ( torch.empty( @@ -1062,7 +1061,7 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_decode_state_cuda_graphs, ) @@ -1766,7 +1765,7 @@ class FlashCausalLM(Model): input_lengths: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: - if not FLASH_INFER: + if ATTENTION != "flashinfer": return nullcontext() from text_generation_server.layers.attention.flash_infer import ( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 42b43c87..b58a5b80 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,16 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} -if FLASH_INFER: - log_master(logger.info, "Using FLASH_INFER") +ATTENTION = os.getenv("ATTENTION", "paged") +_expected = {"paged", "flashdecoding", "flashinfer"} +assert ( + ATTENTION in _expected +), f"Attention is not valid {ATTENTION}, expected {_expected}" +log_master(logger.info, f"Using Attention = {ATTENTION}") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli -FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} -BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 -if FLASH_DECODING: - log_master(logger.info, "Using FLASH_DECODING") +BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 cuda_graphs = os.getenv("CUDA_GRAPHS") From 959add5e9bdd2441232a49c408e147165e04847f Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Aug 2024 10:56:45 -0400 Subject: [PATCH 208/340] feat: add guideline to chat request and template (#2391) * feat: add guideline to chat request and template * fix: add template test and update docs --- docs/openapi.json | 7 +++++++ router/src/infer/chat_template.rs | 15 +++++++++++++++ router/src/infer/mod.rs | 3 ++- router/src/lib.rs | 6 ++++++ router/src/server.rs | 9 +++++++-- 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 4d3a2398..3a4c76bd 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -819,6 +819,13 @@ "example": "1.0", "nullable": true }, + "guideline": { + "type": "string", + "description": "A guideline to be used in the chat_template", + "default": "null", + "example": "null", + "nullable": true + }, "logit_bias": { "type": "array", "items": { diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 24a00352..7c2753ed 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -48,6 +48,7 @@ impl ChatTemplate { pub(crate) fn apply( &self, + guideline: Option<&str>, mut messages: Vec, grammar_with_prompt: Option<(GrammarType, String)>, ) -> Result { @@ -65,6 +66,7 @@ impl ChatTemplate { self.template .render(ChatTemplateInputs { + guideline, messages, bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), @@ -731,6 +733,19 @@ mod tests { }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", }, + ChatTemplateTestItem { + name: "google/shieldgemma-9b", + chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + guideline: Some("Do not use offensive language."), + ..Default::default() + }, + target: "You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n\nHuman Question: I'd like to show off how chat templating works!\n\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n", + }, ]; #[allow(unused_variables)] // name is unused diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 534a2647..58d5cf9a 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -138,13 +138,14 @@ impl Infer { #[instrument(skip_all)] pub(crate) fn apply_chat_template( &self, + guideline: Option, messages: Vec, grammar_with_prompt: Option<(GrammarType, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(messages, grammar_with_prompt) + .apply(guideline.as_deref(), messages, grammar_with_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); diff --git a/router/src/lib.rs b/router/src/lib.rs index 66738706..0a15c495 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -858,6 +858,11 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub response_format: Option, + + /// A guideline to be used in the chat_template + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub guideline: Option, } fn default_tool_prompt() -> Option { @@ -965,6 +970,7 @@ pub(crate) struct ChatTemplateInputs<'a> { add_generation_prompt: bool, tools: Option<&'a str>, tools_prompt: Option<&'a str>, + guideline: Option<&'a str>, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)] diff --git a/router/src/server.rs b/router/src/server.rs index 1d1cd36a..8c0bd762 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -141,6 +141,7 @@ async fn get_chat_tokenize( tool_prompt, temperature, response_format, + guideline, .. } = req; @@ -151,6 +152,7 @@ async fn get_chat_tokenize( tools, tool_choice, &tool_prompt, + guideline, messages, )?; @@ -1123,6 +1125,7 @@ async fn chat_completions( tool_prompt, temperature, response_format, + guideline, .. } = req; @@ -1142,6 +1145,7 @@ async fn chat_completions( tools, tool_choice, &tool_prompt, + guideline, messages, )?; @@ -2402,6 +2406,7 @@ fn prepare_chat_input( tools: Option>, tool_choice: ToolChoice, tool_prompt: &str, + guideline: Option, messages: Vec, ) -> Result { if response_format.is_some() && tools.is_some() { @@ -2411,7 +2416,7 @@ fn prepare_chat_input( } if let Some(format) = response_format { - let inputs = infer.apply_chat_template(messages, None)?; + let inputs = infer.apply_chat_template(guideline, messages, None)?; return Ok((inputs, Some(format), None)); } @@ -2423,6 +2428,6 @@ fn prepare_chat_input( let tools_grammar_prompt = tool_grammar .as_ref() .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); - let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?; Ok((inputs, grammar, tool_grammar)) } From bb833389e04eac1637029c10ea00139e218288d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 22:36:51 +0200 Subject: [PATCH 209/340] Update flake for 9.0a capability in Torch (#2394) --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 043218e8..7889e7cf 100644 --- a/flake.lock +++ b/flake.lock @@ -102,11 +102,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1723188417, - "narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=", + "lastModified": 1723234585, + "narHash": "sha256-HChJpNP155FPhHr9C5BtqllV8Uv/Ebg59HhMc/HhQrc=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3", + "rev": "15bd4a978d6c2c8b04b7afc335d137dbe41e73df", "type": "github" }, "original": { From 197dd3af12a1453f83fb530bcf6e34752b35f0a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 12 Aug 2024 09:28:38 +0200 Subject: [PATCH 210/340] nix: add router to the devshell (#2396) --- flake.nix | 4 ++++ router.nix | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 router.nix diff --git a/flake.nix b/flake.nix index fdd67d00..1d14561f 100644 --- a/flake.nix +++ b/flake.nix @@ -70,6 +70,10 @@ torch transformers vllm + + (callPackage ./router.nix { + inherit (rustPlatform) buildRustPackage importCargoLock; + }) ]); venvDir = "./.venv"; diff --git a/router.nix b/router.nix new file mode 100644 index 00000000..eeeac199 --- /dev/null +++ b/router.nix @@ -0,0 +1,18 @@ +{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: + +buildRustPackage { + name = "text-generation-router"; + + src = ./.; + + sourceDir = ./backends/v3; + + cargoLock = { + lockFile = ./Cargo.lock; + }; + + nativeBuildInputs = [ pkg-config ]; + + buildInputs = [ openssl.dev protobuf ]; + +} From 8750dc878ecd53aa2d3fc957f621136b419b19ab Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 14:08:38 +0200 Subject: [PATCH 211/340] Upgrade fbgemm (#2398) * Upgrade fbgemm * Fix fbgemm version --- .../models/test_grammar_response_format_llama.py | 4 ++-- server/Makefile-fbgemm | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index ea25fa1c..25bf9d98 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -98,6 +98,6 @@ async def test_grammar_response_format_llama_error_if_tools_not_installed( # 422 means the server was unable to process the request because it contains invalid data. assert response.status_code == 422 assert response.json() == { - "error": "Grammar and tools are mutually exclusive", - "error_type": "grammar and tools", + "error": "Tool error: Grammar and tools are mutually exclusive", + "error_type": "tool_error", } diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm index 57526577..5f3c0eaa 100644 --- a/server/Makefile-fbgemm +++ b/server/Makefile-fbgemm @@ -1,4 +1,4 @@ -fbgemm_commit := ddac8dd9fc0bee70a3f456df68b8aac38576c856 +fbgemm_commit := v0.8.0 build-fbgemm: git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ From fbe59c6267e48c04962ca7053d0f0f0eb00346c4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 14:08:46 +0200 Subject: [PATCH 212/340] Adding launcher to build. (#2397) --- _launcher.nix | 18 ++++++++++++++++++ flake.nix | 3 +++ 2 files changed, 21 insertions(+) create mode 100644 _launcher.nix diff --git a/_launcher.nix b/_launcher.nix new file mode 100644 index 00000000..1acae7e1 --- /dev/null +++ b/_launcher.nix @@ -0,0 +1,18 @@ +{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: + +buildRustPackage { + name = "text-generation-lancher"; + + src = ./.; + + sourceDir = ./launcher; + + cargoLock = { + lockFile = ./Cargo.lock; + }; + + nativeBuildInputs = [ pkg-config ]; + + buildInputs = [ openssl.dev protobuf ]; + +} diff --git a/flake.nix b/flake.nix index 1d14561f..761c4af8 100644 --- a/flake.nix +++ b/flake.nix @@ -74,6 +74,9 @@ (callPackage ./router.nix { inherit (rustPlatform) buildRustPackage importCargoLock; }) + (callPackage ./_launcher.nix { + inherit (rustPlatform) buildRustPackage importCargoLock; + }) ]); venvDir = "./.venv"; From 1daaddd072bbad00325d8eadcd474244a4603d10 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 14:08:59 +0200 Subject: [PATCH 213/340] Fixing import exl2 (#2399) --- .../layers/gptq/__init__.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index f6616d3e..9c9b69d1 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -8,34 +8,6 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - -HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, # noqa: F401 - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass - @dataclass class GPTQWeight(Weight): @@ -432,3 +404,33 @@ class GPTQWeightsLoader(WeightsLoader): else False ) self.quant_method = "gptq" + + +# Needs to be at the end because circular import. +try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 + +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + if V2: + from text_generation_server.layers.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, # noqa: F401 + ) + + HAS_EXLLAMA = "2" + else: + from text_generation_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 + ) + + HAS_EXLLAMA = "1" + + except ImportError: + pass From b8efd6d00c62174bfc9d020775e2329fe10b6b01 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 12 Aug 2024 20:10:30 +0800 Subject: [PATCH 214/340] Cpu dockerimage (#2367) add intel-cpu docker image Signed-off-by: Wang, Yi A --- .github/workflows/ci_build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 5ca2854a..6000cec3 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -37,7 +37,7 @@ jobs: # fail-fast is true by default fail-fast: false matrix: - hardware: ["cuda", "rocm", "intel"] + hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"] uses: ./.github/workflows/build.yaml # calls the one above ^ with: hardware: ${{ matrix.hardware }} From f586cc7f0ca895e250eea08637eeecbd9a4c8a23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 12 Aug 2024 14:59:17 +0200 Subject: [PATCH 215/340] Add support for prefix caching to the v3 router (#2392) This change adds support for prefix caching to the v3 router. This is broken up from the backend support to ease reviewing. For now prefix caching is only enabled with `USE_PREFIX_CACHING=1` in this case, the router will switch to `RadixAllocator`. This allocator uses a radix trie to keep track of prefills that were seen prior. If a new prefill is a prefix of a previously-seen prefil, the router will send a request with `prefix_len>0`, which can be used by the backend to decide to reuse KV blocks from the cache, rather than recomputing them. Even though backend support is not added in this PR, the backend will still work with prefix caching enabled. The prefix lengths are just ignored and not used. --- Cargo.lock | 1 + backends/client/src/v3/client.rs | 1 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/Cargo.toml | 1 + backends/v3/src/backend.rs | 10 + backends/v3/src/block_allocator.rs | 182 +++-- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/sharded_client.rs | 1 + backends/v3/src/lib.rs | 1 + backends/v3/src/queue.rs | 53 +- backends/v3/src/radix.rs | 755 ++++++++++++++++++ benchmark/src/generation.rs | 1 + proto/v3/generate.proto | 356 ++++----- router/src/validation.rs | 19 +- .../text_generation_server/models/globals.py | 17 +- 15 files changed, 1145 insertions(+), 255 deletions(-) create mode 100644 backends/v3/src/radix.rs diff --git a/Cargo.lock b/Cargo.lock index c4cf8dc8..6cadbb45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4091,6 +4091,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "slotmap", "text-generation-router", "thiserror", "tokenizers", diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index a996b14f..b321278c 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -156,6 +156,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index ae8a899b..1cc173e3 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 5d9a140b..129ceb9c 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -33,6 +33,7 @@ rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" +slotmap = "1.0.7" thiserror = "1.0.48" tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 68ddf00b..cbcbff72 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,15 +35,24 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { + let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { + matches!(prefix_caching.as_str(), "true" | "1") + } else { + false + }; let attention = if let Ok(attention) = std::env::var("ATTENTION") { attention .parse() .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) + } else if prefix_caching { + Attention::FlashInfer } else { Attention::Paged }; let block_size = if attention == Attention::FlashDecoding { 256 + } else if attention == Attention::FlashInfer { + 1 } else { 16 }; @@ -51,6 +60,7 @@ impl BackendV3 { let queue = Queue::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 7467fd85..05c2bd30 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,16 +1,26 @@ -use std::cmp::min; +use std::{cmp::min, sync::Arc}; use tokio::sync::{mpsc, oneshot}; +use crate::radix::RadixAllocator; + #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { + pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, - block_allocator: BlockAllocator, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } } } @@ -24,6 +34,7 @@ impl BlockAllocator { pub(crate) fn new( max_batch_total_tokens: u32, block_size: u32, + prefix_caching: bool, window_size: Option, ) -> Self { // Create channel @@ -33,6 +44,7 @@ impl BlockAllocator { tokio::spawn(block_allocator_task( max_batch_total_tokens / block_size, block_size, + prefix_caching, window_size, receiver, )); @@ -42,28 +54,32 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -71,54 +87,29 @@ impl BlockAllocator { async fn block_allocator_task( blocks: u32, block_size: u32, + prefix_caching: bool, window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); } } } @@ -128,9 +119,92 @@ async fn block_allocator_task( enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, }, } + +pub(crate) trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} + +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index c407687b..6282759e 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -157,6 +157,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index afb13cdc..2f78da03 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -245,6 +245,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index a6f89169..c8fc55f8 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -2,6 +2,7 @@ mod backend; mod block_allocator; mod client; mod queue; +mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index b457389c..13544235 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -46,6 +46,7 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -57,6 +58,7 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -109,6 +111,7 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -117,6 +120,7 @@ async fn queue_task( let mut state = State::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -176,12 +180,19 @@ impl State { fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, ) -> Self { - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); Self { entries: VecDeque::with_capacity(128), @@ -305,7 +316,10 @@ impl State { + self.speculate - 1; - match block_allocator.allocate(tokens).await { + match block_allocator + .allocate(tokens, entry.request.input_ids.clone()) + .await + { None => { // Entry is over budget // Add it back to the front @@ -331,11 +345,12 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -372,6 +387,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + prefix_len, adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time @@ -480,6 +496,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -492,6 +510,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, truncate: 0, decoder_input_details: false, @@ -527,7 +546,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -543,7 +562,7 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -551,7 +570,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -583,7 +602,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -603,7 +622,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -636,14 +655,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -651,7 +670,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -684,7 +703,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -700,7 +719,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -725,7 +744,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(false, 1, false, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -744,7 +763,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 00000000..0464b9f8 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,755 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use slotmap::{DefaultKey, SlotMap}; + +use crate::block_allocator::{Allocator, BlockAllocation}; + +pub struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + assert_eq!( + block_size, 1, + "Radix tree allocator only works with block_size=1, was: {}", + block_size + ); + if window_size.is_some() { + unimplemented!("Window size not supported in the prefix-caching block allocator yet"); + } + + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. + + node_id + } else { + self.cache_blocks.root_id() + }; + + self.cache_blocks + .incref(prefix_node) + .expect("Failed to increment refcount"); + + let prefix_len = blocks.len(); + let suffix_len = tokens - prefix_len as u32; + + match self.alloc_or_reclaim(suffix_len as usize) { + Some(suffix_blocks) => blocks.extend(suffix_blocks), + None => { + self.cache_blocks + .decref(prefix_node) + .expect("Failed to decrement refcount"); + return None; + } + } + + // 1:1 mapping of blocks and slots. + let slots = blocks.clone(); + + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some(BlockAllocation { + allocation_id: self.allocation_id, + block_allocator: None, + blocks, + slots, + prefix_len: prefix_len as u32, + }) + } + + fn free(&mut self, blocks: Vec, allocation_id: u64) { + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks + .decref(allocation.prefix_node) + .expect("Failed to decrement refcount"); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + let prefix_len = self + .cache_blocks + .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + self.free_blocks + .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + } + + // Free non-prefill blocks. + self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + } else { + self.free_blocks.extend(blocks); + } + } +} + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. + +#[derive(Debug)] +pub enum TrieError { + InvalidNodeId, + RefCountUnderflow, + BlockTokenCountMismatch, +} + +pub type NodeId = DefaultKey; + +#[derive(Debug)] +pub struct RadixTrie { + /// Identifier of the root nod. + root: DefaultKey, + + /// Leave node identifiers ordered by increasing recency. + leaves: BTreeSet<(u64, NodeId)>, + + /// All trie nodes. + nodes: SlotMap, + + /// Time as a monotonically increating counter to avoid the system + /// call that a real time lookup would require. + time: u64, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new() -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if let Some(&child_id) = node.children.get(&key[0]) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = child.key.shared_prefix_len(key); + blocks.extend(&child.blocks[..shared_prefix_len]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + + node_id + } + + /// Decrease the reference count of a node. + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could beevicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. + let mut evicted = Vec::new(); + + while let Some((last_access, node_id)) = self.leaves.pop_first() { + let blocks_needed = n_blocks - evicted.len(); + + let node = self.nodes.get(node_id).expect("Leave does not exist"); + if blocks_needed >= node.blocks.len() { + // We need to evict the whole node if we need more blocks than it has. + let node = self.remove_node(node_id); + evicted.extend(node.blocks); + + if evicted.len() >= n_blocks { + break; + } + } else { + // The node has more blocks than needed, so we'll just remove + // the required number of blocks and leave the remaining blocks + // untouched. + let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + node.key.truncate(node.blocks.len() - blocks_needed); + evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + self.insert_(self.root, tokens, blocks) + } + + /// Insertion worker. + fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. + + if tokens.len() != blocks.len() { + return Err(TrieError::BlockTokenCountMismatch); + } + + if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = child.key.shared_prefix_len(tokens); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This funcion unwraps, an invalid node_id is a programming error. + + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let mut parent_blocks = node.blocks.split_off(prefix_len); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = node.key[0]; + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. + fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = key[0]; + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(first, child_id).is_none() { + // Only increase reference count if child does not replace another child. + self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + parent.children.remove(&node.key[0]); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + self.nodes.remove(node_id); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. + if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +/// Helper trait to get the length of the shared prefix of two sequences. +trait SharedPrefixLen { + fn shared_prefix_len(&self, other: &Self) -> usize; +} + +impl SharedPrefixLen for [T] +where + T: PartialEq, +{ + fn shared_prefix_len(&self, other: &Self) -> usize { + self.iter().zip(other).take_while(|(a, b)| a == b).count() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::block_allocator::Allocator; + + use super::RadixAllocator; + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. + assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = super::RadixTrie::new(); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. + assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703..7494d5b5 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,6 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + prefix_len: 0, adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 926c878e..68eea7ac 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -3,22 +3,23 @@ syntax = "proto3"; package generate.v3; service TextGenerationService { - /// Model Info - rpc Info (InfoRequest) returns (InfoResponse) {} - /// Service discovery - rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} - /// Empties batch cache - rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Remove requests from a cached batch - rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); - /// Warmup the model and compute max cache size - rpc Warmup (WarmupRequest) returns (WarmupResponse); - /// Prefill batch and decode first token - rpc Prefill (PrefillRequest) returns (PrefillResponse); - /// Decode token for a list of prefilled batches - rpc Decode (DecodeRequest) returns (DecodeResponse); - /// Health check - rpc Health (HealthRequest) returns (HealthResponse); + /// Model Info + rpc Info(InfoRequest) returns (InfoResponse) {} + /// Service discovery + rpc ServiceDiscovery(ServiceDiscoveryRequest) + returns (ServiceDiscoveryResponse) {} + /// Empties batch cache + rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse); + /// Remove requests from a cached batch + rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup(WarmupRequest) returns (WarmupResponse); + /// Prefill batch and decode first token + rpc Prefill(PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc Decode(DecodeRequest) returns (DecodeResponse); + /// Health check + rpc Health(HealthRequest) returns (HealthResponse); } message HealthRequest {} @@ -28,240 +29,239 @@ message HealthResponse {} message InfoRequest {} message InfoResponse { - bool requires_padding = 1; - string dtype = 2; - string device_type = 3; - optional uint32 window_size = 4; - uint32 speculate = 5; + bool requires_padding = 1; + string dtype = 2; + string device_type = 3; + optional uint32 window_size = 4; + uint32 speculate = 5; } /// Empty request message ServiceDiscoveryRequest {} message ServiceDiscoveryResponse { - /// Other shards urls - repeated string urls = 1; + /// Other shards urls + repeated string urls = 1; } message ClearCacheRequest { - /// Optional batch id - optional uint64 id = 1; + /// Optional batch id + optional uint64 id = 1; } /// Empty response message ClearCacheResponse {} message Image { - /// Binary image data. - bytes data = 1; + /// Binary image data. + bytes data = 1; - /// Image MIME type. - string mimetype = 2; + /// Image MIME type. + string mimetype = 2; } message InputChunk { - oneof chunk { - /// Plain text data - string text = 1; - /// Image data - Image image = 2; - } + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } } -message Input { - repeated InputChunk chunks = 1; - } +message Input { repeated InputChunk chunks = 1; } enum GrammarType { - GRAMMAR_TYPE_NONE = 0; - GRAMMAR_TYPE_JSON = 1; - GRAMMAR_TYPE_REGEX = 2; + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; } message NextTokenChooserParameters { - /// exponential scaling output probability distribution - float temperature = 1; - /// restricting to the k highest probability elements - uint32 top_k = 2; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float top_p = 3; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float typical_p = 4; - /// apply sampling on the logits - bool do_sample = 5; - /// random seed for sampling - uint64 seed = 6; - /// repetition penalty - float repetition_penalty = 7; - /// frequency penalty - float frequency_penalty = 9; - /// token watermarking using "A Watermark for Large Language Models" - bool watermark = 8; - /// grammar (applied if not empty) - string grammar = 10; - /// grammar type - GrammarType grammar_type = 11; + /// exponential scaling output probability distribution + float temperature = 1; + /// restricting to the k highest probability elements + uint32 top_k = 2; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float top_p = 3; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float typical_p = 4; + /// apply sampling on the logits + bool do_sample = 5; + /// random seed for sampling + uint64 seed = 6; + /// repetition penalty + float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 9; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; } message StoppingCriteriaParameters { - /// Maximum number of generated tokens - uint32 max_new_tokens = 1; - /// Optional stopping sequences - repeated string stop_sequences = 2; - /// Ignore end of sequence token - /// used for benchmarking - bool ignore_eos_token = 3; + /// Maximum number of generated tokens + uint32 max_new_tokens = 1; + /// Optional stopping sequences + repeated string stop_sequences = 2; + /// Ignore end of sequence token + /// used for benchmarking + bool ignore_eos_token = 3; } message Request { - /// Request ID - uint64 id = 1; - /// The generation context as chunks - Input input_chunks = 8; - /// The generation context, stringified input_chunks - string inputs = 2; - /// Context truncation - uint32 truncate = 3; - /// Next Token Chooser Parameters - NextTokenChooserParameters parameters = 4; - /// Stopping Criteria Parameters - StoppingCriteriaParameters stopping_parameters = 5; - /// Return prefill logprobs - bool prefill_logprobs = 6; - /// Return most likely n tokens - uint32 top_n_tokens = 7; - /// Paged attention blocks - repeated uint32 blocks = 9; - /// Paged attention slots - repeated uint32 slots = 10; - /// LORA adapter index - optional string adapter_id = 11; + /// Request ID + uint64 id = 1; + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks + string inputs = 2; + /// Context truncation + uint32 truncate = 3; + /// Next Token Chooser Parameters + NextTokenChooserParameters parameters = 4; + /// Stopping Criteria Parameters + StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_n_tokens = 7; + /// Paged attention blocks + repeated uint32 blocks = 9; + /// Paged attention slots + repeated uint32 slots = 10; + /// LORA adapter index + optional string adapter_id = 11; + /// Prefix length that can be retrieved from the KV cache. + uint32 prefix_len = 12; } message Batch { - /// Batch ID - uint64 id = 1; - /// Individual requests - repeated Request requests = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; - /// Maximum number of Paged Attention blocks - uint32 max_blocks = 5; + /// Batch ID + uint64 id = 1; + /// Individual requests + repeated Request requests = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; + /// Maximum number of Paged Attention blocks + uint32 max_blocks = 5; } message CachedBatch { - /// Batch ID - uint64 id = 1; - /// Individual requests ids - repeated uint64 request_ids = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; + /// Batch ID + uint64 id = 1; + /// Individual requests ids + repeated uint64 request_ids = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; } enum FinishReason { - FINISH_REASON_LENGTH = 0; - FINISH_REASON_EOS_TOKEN = 1; - FINISH_REASON_STOP_SEQUENCE = 2; + FINISH_REASON_LENGTH = 0; + FINISH_REASON_EOS_TOKEN = 1; + FINISH_REASON_STOP_SEQUENCE = 2; } message GeneratedText { - /// Output - string text = 1; - /// Number of generated tokens - uint32 generated_tokens = 2; - /// Finish reason - FinishReason finish_reason = 3; - /// Seed - optional uint64 seed = 4; + /// Output + string text = 1; + /// Number of generated tokens + uint32 generated_tokens = 2; + /// Finish reason + FinishReason finish_reason = 3; + /// Seed + optional uint64 seed = 4; } message Tokens { - /// Token IDs - repeated uint32 ids = 1; - /// Logprobs - repeated float logprobs = 2; - /// tokens - repeated string texts = 3; - /// special - repeated bool is_special = 4; + /// Token IDs + repeated uint32 ids = 1; + /// Logprobs + repeated float logprobs = 2; + /// tokens + repeated string texts = 3; + /// special + repeated bool is_special = 4; } message Generation { - /// Request ID - uint64 request_id = 1; - /// Prefill tokens (optional) - Tokens prefill_tokens = 2; - Tokens tokens = 3; - /// Complete generated text - optional GeneratedText generated_text = 4; - /// Top tokens - repeated Tokens top_tokens = 5; + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + Tokens prefill_tokens = 2; + Tokens tokens = 3; + /// Complete generated text + optional GeneratedText generated_text = 4; + /// Top tokens + repeated Tokens top_tokens = 5; } message FilterBatchRequest { - /// Batch ID - uint64 batch_id = 1; - /// Requests to keep - repeated uint64 request_ids = 2; + /// Batch ID + uint64 batch_id = 1; + /// Requests to keep + repeated uint64 request_ids = 2; } message FilterBatchResponse { - /// Filtered Batch (cached) - CachedBatch batch = 1; + /// Filtered Batch (cached) + CachedBatch batch = 1; } - message PrefillRequest { - /// Batch - Batch batch = 1; + /// Batch + Batch batch = 1; } message PrefillResponse { - /// Generation - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; + /// Generation + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; } message DecodeRequest { - /// Cached batches - repeated CachedBatch batches = 1; + /// Cached batches + repeated CachedBatch batches = 1; } message DecodeResponse { - /// Decodes - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; - /// Concatenate elapsed time in nanoseconds - optional uint64 concat_ns = 6; + /// Decodes + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; } message WarmupRequest { - /// Batch to warmup on - Batch batch = 1; - uint32 max_input_length = 2; - uint32 max_prefill_tokens = 3; - uint32 max_total_tokens = 4; + /// Batch to warmup on + Batch batch = 1; + uint32 max_input_length = 2; + uint32 max_prefill_tokens = 3; + uint32 max_total_tokens = 4; } message WarmupResponse { - /// Maximum number of tokens supported by the model - optional uint32 max_supported_total_tokens = 1; + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; } diff --git a/router/src/validation.rs b/router/src/validation.rs index 3d1a4103..5011158a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -11,6 +11,7 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; +use std::sync::Arc; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -115,13 +116,14 @@ impl Validation { } } + #[allow(clippy::type_complexity)] #[instrument(skip(self, inputs))] async fn validate_input( &self, inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -156,8 +158,10 @@ impl Validation { )); } + let input_ids = encoding.get_ids()[..input_length].to_owned(); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); - Ok((inputs, input_length, max_new_tokens)) + Ok((inputs, Some(input_ids), input_length, max_new_tokens)) } // Return inputs without validation else { @@ -180,7 +184,12 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs)], + None, + input_length, + max_new_tokens, + )) } } @@ -314,7 +323,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let (inputs, input_length, max_new_tokens) = self + let (inputs, input_ids, input_length, max_new_tokens) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -391,6 +400,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + input_ids: input_ids.map(Arc::new), decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -707,6 +717,7 @@ pub struct ValidStoppingParameters { #[derive(Debug, Clone)] pub struct ValidGenerateRequest { pub inputs: Vec, + pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index b58a5b80..abc35421 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,29 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -ATTENTION = os.getenv("ATTENTION", "paged") +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) +log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") + +ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") +if PREFIX_CACHING and ATTENTION != "flashinfer": + raise RuntimeError("Prefix caching is only supported with flashinfer") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + # This is overridden by the cli -BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 +BLOCK_SIZE: int +if ATTENTION == "flashdecoding": + BLOCK_SIZE = 256 +elif ATTENTION == "flashinfer": + BLOCK_SIZE = 1 +else: + BLOCK_SIZE = 16 cuda_graphs = os.getenv("CUDA_GRAPHS") From 6393cdee630223d17eb32aef30c73840d8eb1b49 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 15:22:02 +0200 Subject: [PATCH 216/340] Keeping the benchmark somewhere (#2401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniël de Kok --- Cargo.lock | 136 +++++++++++++++++++++++++++- backends/v3/Cargo.toml | 22 ++++- backends/v3/benches/prefix_cache.rs | 47 ++++++++++ backends/v3/src/block_allocator.rs | 6 +- backends/v3/src/lib.rs | 4 +- backends/v3/src/queue.rs | 2 +- backends/v3/src/radix.rs | 5 + router/Cargo.toml | 20 +++- 8 files changed, 224 insertions(+), 18 deletions(-) create mode 100644 backends/v3/benches/prefix_cache.rs diff --git a/Cargo.lock b/Cargo.lock index 6cadbb45..405922d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,6 +180,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -697,6 +708,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -1027,7 +1074,7 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", "flume", - "half", + "half 2.4.1", "lebe", "miniz_oxide", "rayon-core", @@ -1328,6 +1375,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -1365,6 +1418,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1865,6 +1927,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1905,7 +1976,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap", + "clap 4.5.11", "fancy-regex", "fraction", "getrandom", @@ -2199,7 +2270,7 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -2467,7 +2538,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2523,6 +2594,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openssl" version = "0.10.66" @@ -2850,6 +2927,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.13" @@ -3593,6 +3698,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.204" @@ -4068,13 +4183,15 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.11", + "criterion", "futures", "futures-util", "grpc-metadata", "hf-hub", "image", "init-tracing-opentelemetry", + "itertools 0.13.0", "jsonschema", "metrics", "metrics-exporter-prometheus", @@ -4108,6 +4225,15 @@ dependencies = [ "utoipa-swagger-ui", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.63" diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 129ceb9c..06a44bec 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -35,8 +35,14 @@ serde = "1.0.188" serde_json = "1.0.107" slotmap = "1.0.7" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" @@ -44,7 +50,9 @@ tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -60,8 +68,16 @@ tower = "^0.4" tonic-build = "0.10.1" prost-build = "0.12.1" +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + [features] default = ["ngrok"] ngrok = ["text-generation-router/ngrok"] google = ["text-generation-router/google"] kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 00000000..d9df45b2 --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::Rng; + +use text_generation_router_v3::block_allocator::Allocator; +use text_generation_router_v3::radix::RadixAllocator; + +fn prefix_cache_benchmark(c: &mut Criterion) { + // let prefixes: Vec> = (0..8192) + // .chunks(256) + // .into_iter() + // .map(|c| c.collect()) + // .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("Radix allocator", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate( + prefill.len() as u32 + 13, + Some(Arc::new(black_box(prefill))), + ); + if let Some(alloc) = alloc { + cache.free(alloc.blocks.clone(), alloc.allocation_id); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 05c2bd30..c5503b8c 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -4,7 +4,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; #[derive(Debug, Clone)] -pub(crate) struct BlockAllocation { +pub struct BlockAllocation { pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, @@ -25,7 +25,7 @@ impl Drop for BlockAllocation { } #[derive(Debug, Clone)] -pub(crate) struct BlockAllocator { +pub struct BlockAllocator { /// Channel to communicate with the background task block_allocator: mpsc::UnboundedSender, } @@ -128,7 +128,7 @@ enum BlockAllocatorCommand { }, } -pub(crate) trait Allocator { +pub trait Allocator { fn allocate( &mut self, tokens: u32, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index c8fc55f8..77a9a11a 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -1,8 +1,8 @@ mod backend; -mod block_allocator; +pub mod block_allocator; mod client; mod queue; -mod radix; +pub mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 13544235..0fb05a98 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -250,7 +250,7 @@ impl State { // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); + next_batch_span.follows_from(Span::current()); let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_entries = diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 0464b9f8..ef963532 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,11 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } +impl Default for RadixTrie { + fn default() -> Self { + Self::new() + } +} impl RadixTrie { /// Construct a new radix trie. diff --git a/router/Cargo.toml b/router/Cargo.toml index 1be74546..7773e212 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -27,8 +27,14 @@ reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.40" @@ -37,7 +43,9 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -46,7 +54,11 @@ once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" -uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +uuid = { version = "1.9.1", default-features = false, features = [ + "v4", + "fast-rng", + "macro-diagnostics", +] } csv = "1.3.0" ureq = "=2.9" From 8e6bfa2fc59ce4f03ca1faf7e27b04a1b7759078 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 10:58:40 -0400 Subject: [PATCH 217/340] =?UTF-8?q?feat:=20validate=20template=20variables?= =?UTF-8?q?=20before=20apply=20and=20improve=20sliding=20wi=E2=80=A6=20(#2?= =?UTF-8?q?403)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: validate template variables before apply and improve sliding window check * fix: improve missing template var test --- router/src/infer/chat_template.rs | 61 +++++++++++++++++-- router/src/infer/mod.rs | 3 + router/src/server.rs | 1 + router/src/validation.rs | 2 +- .../text_generation_server/models/__init__.py | 18 +++--- 5 files changed, 70 insertions(+), 15 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 7c2753ed..ef4beee2 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use crate::infer::InferError; use crate::{ ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, @@ -16,6 +18,7 @@ pub(crate) struct ChatTemplate { bos_token: Option, eos_token: Option, use_default_tool_template: bool, + variables: HashSet, } impl ChatTemplate { @@ -30,19 +33,22 @@ impl ChatTemplate { let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); // leaking env and template_str as read-only, static resources for performance. let template = Box::leak(env) .template_from_str(Box::leak(template_str)) .unwrap(); + // get the list of variables that are used in the template + let variables = template.undeclared_variables(true); + // check if the `tools` variable is used in the template + let use_default_tool_template = !variables.contains("tools"); + Self { template, bos_token: bos_token.map(|token| token.as_str().to_string()), eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, + variables, } } @@ -64,6 +70,11 @@ impl ChatTemplate { let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + // check if guideline is expected but not provided + if self.variables.contains("guideline") && guideline.is_none() { + return Err(InferError::MissingTemplateVariable("guideline".to_string())); + } + self.template .render(ChatTemplateInputs { guideline, @@ -82,7 +93,8 @@ impl ChatTemplate { #[cfg(test)] mod tests { use crate::infer::chat_template::raise_exception; - use crate::{ChatTemplateInputs, TextMessage}; + use crate::infer::ChatTemplate; + use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken}; use minijinja::Environment; #[test] @@ -770,4 +782,45 @@ mod tests { assert_eq!(result, target); } } + + #[test] + fn test_chat_template_invalid_with_guideline() { + let ct = ChatTemplate::new( + "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText( + "I'm doing great. How can I help you today?".to_string(), + ), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Hello, how are you?".to_string()), + }, + ]; + + let result = ct.apply(None, msgs, None); + + match result { + Ok(_) => panic!("Should have failed since no guideline is provided"), + Err(e) => { + assert_eq!(e.to_string(), "Missing template vatiable: guideline") + } + } + } } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 58d5cf9a..c9354d9a 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -337,6 +337,8 @@ pub enum InferError { IncompleteGeneration, #[error("Template error: {0}")] TemplateError(#[from] minijinja::Error), + #[error("Missing template vatiable: {0}")] + MissingTemplateVariable(String), #[error("Tool error: {0}")] ToolError(String), } @@ -349,6 +351,7 @@ impl InferError { InferError::ValidationError(_) => "validation", InferError::IncompleteGeneration => "incomplete_generation", InferError::TemplateError(_) => "template_error", + InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::ToolError(_) => "tool_error", } } diff --git a/router/src/server.rs b/router/src/server.rs index 8c0bd762..99ec077f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2297,6 +2297,7 @@ impl From for (StatusCode, Json) { InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, }; diff --git a/router/src/validation.rs b/router/src/validation.rs index 5011158a..0024723c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -830,7 +830,7 @@ mod tests { .await { // Err(ValidationError::MaxNewTokens(1, 10)) => (), - Ok((_s, 0, 10)) => (), + Ok((_s, _, 0, 10)) => (), r => panic!("Unexpected not max new tokens: {r:?}"), } } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 960b426b..4fa9e66d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -497,17 +497,15 @@ def get_model( else -1 ) - if max_input_tokens is not None and max_input_tokens <= sliding_window: - sliding_window = -1 + should_use_sliding_window = ( + sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING + ) - if ( - (sliding_window is not None and sliding_window != -1) - and not SUPPORTS_WINDOWING - and max_input_tokens > sliding_window - ): - raise ValueError( - f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." - ) + if should_use_sliding_window: + if max_input_tokens is not None and max_input_tokens > sliding_window: + raise ValueError( + f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." + ) if model_type == DEEPSEEK_V2: if FLASH_ATTENTION: From 3079865b60dae4dd0a364690cd3fcc57d4908bad Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 11:24:32 -0400 Subject: [PATCH 218/340] fix: allocate tmp based on sgmv kernel if available (#2345) * fix: allocate tmp based on sgmv kernel if available * fix: re add copy build artifacts step for punica kernels --- server/text_generation_server/utils/sgmv.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py index e0aec25f..2d0a73a5 100644 --- a/server/text_generation_server/utils/sgmv.py +++ b/server/text_generation_server/utils/sgmv.py @@ -151,13 +151,17 @@ def get_tmp_expand_size(size: int) -> int: def get_tmp_tensors( nsegments: int, lora_rank: int, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor]: - if use_cutlass_shrink(lora_rank) and has_sgmv(): + use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv() + has_sgmv_available = has_sgmv() + + if use_cutlass: tmp = get_tmp_tensor_for_size(nsegments, device) return tmp, tmp + elif has_sgmv_available: + return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device) else: - tmp_shrink = get_tmp_tensor(device) - tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device) - return tmp_shrink, tmp_expand + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp def lora_a_sgmv_cutlass( From 96e8fa37b0a87ab3c522282aa9e7f9e8d5c8ecbc Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 11:26:11 -0400 Subject: [PATCH 219/340] fix: improve completions to send a final chunk with usage details (#2336) * fix: improve completions to send a final chunk with usage details * fix: include finish reason string * fix: remove dev debug trait and unneeded mut * fix: update openapi schema --- docs/openapi.json | 9 ++++++++- router/src/lib.rs | 2 ++ router/src/server.rs | 42 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 3a4c76bd..d0f78061 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1824,7 +1824,8 @@ "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "input_length" ], "properties": { "finish_reason": { @@ -1836,6 +1837,12 @@ "example": 1, "minimum": 0 }, + "input_length": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, "seed": { "type": "integer", "format": "int64", diff --git a/router/src/lib.rs b/router/src/lib.rs index 0a15c495..d7eb4475 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1219,6 +1219,8 @@ pub(crate) struct StreamDetails { pub generated_tokens: u32, #[schema(nullable = true, example = 42)] pub seed: Option, + #[schema(example = 1)] + pub input_length: u32, } #[derive(Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 99ec077f..ab268efa 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -533,7 +533,7 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, response_stream)) => { + Ok((_permit, input_length, response_stream)) => { let mut index = 0; let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream @@ -576,6 +576,7 @@ async fn generate_stream_internal( finish_reason: generated_text.finish_reason, generated_tokens: generated_text.generated_tokens, seed: generated_text.seed, + input_length, }), false => None, }; @@ -801,21 +802,46 @@ async fn completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - event - .json_data(Completion::Chunk(Chunk { - id: "".to_string(), - created: current_time, + let message = match stream_token.details { + Some(details) => { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Completion::Final(CompletionFinal { + id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![CompletionComplete { + finish_reason: details.finish_reason.to_string(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }) + } + None => Completion::Chunk(Chunk { + id: String::new(), + created: current_time, choices: vec![CompletionComplete { - finish_reason: "".to_string(), + finish_reason: String::new(), index: index as u32, logprobs: None, text: stream_token.token.text, }], - model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - })) + }), + }; + + event + .json_data(message) .unwrap_or_else(|_e| Event::default()) }; From 18d6be6af4bacb0285ad22c123679ce0adc90e34 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 18:09:16 +0200 Subject: [PATCH 220/340] Updating the flake. (#2404) --- _server.nix | 12 ++++ flake.lock | 158 +++++++++++++++++++++++++++++++++++++++++++++++----- flake.nix | 10 ++++ 3 files changed, 165 insertions(+), 15 deletions(-) create mode 100644 _server.nix diff --git a/_server.nix b/_server.nix new file mode 100644 index 00000000..bc11ba6b --- /dev/null +++ b/_server.nix @@ -0,0 +1,12 @@ +{ mkPoetryApplication, pkg-config, protobuf, openssl }: + +mkPoetryApplication { + # name = "text-generation-server"; + + projectDir = ./server; + + # nativeBuildInputs = [ pkg-config ]; + + # buildInputs = [ openssl.dev protobuf ]; + +} diff --git a/flake.lock b/flake.lock index 7889e7cf..47f14626 100644 --- a/flake.lock +++ b/flake.lock @@ -33,19 +33,96 @@ "type": "github" } }, - "nixpkgs": { + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, "locked": { - "lastModified": 1723099294, - "narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=", - "owner": "danieldk", - "repo": "nixpkgs", - "rev": "45892b6ec142eaf300d777926983a433b5842c88", + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", "type": "github" }, "original": { - "owner": "danieldk", - "ref": "cudnn-9.3", + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-github-actions": { + "inputs": { + "nixpkgs": [ + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1703863825, + "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=", + "owner": "nix-community", + "repo": "nix-github-actions", + "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nix-github-actions", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1719763542, + "narHash": "sha256-mXkOj9sJ0f69Nkc2dGGOWtof9d1YNY8Le/Hia3RN+8Q=", + "owner": "NixOS", "repo": "nixpkgs", + "rev": "e6cdd8a11b26b4d60593733106042141756b54a3", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1723418128, + "narHash": "sha256-k1pEqsnB6ikZyasXbtV6A9akPZMKlsyENPDUA6PXoJo=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "129f579cbb5b4c1ad258fd96bdfb78eb14802727", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "poetry2nix": { + "inputs": { + "flake-utils": "flake-utils_2", + "nix-github-actions": "nix-github-actions", + "nixpkgs": "nixpkgs", + "systems": "systems_3", + "treefmt-nix": "treefmt-nix" + }, + "locked": { + "lastModified": 1723343306, + "narHash": "sha256-/6sRkPq7/5weX2y0V8sQ29Sz35nt8kyj+BsFtkhgbJE=", + "owner": "nix-community", + "repo": "poetry2nix", + "rev": "4a1c112ff0c67f496573dc345bd0b2247818fc29", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "poetry2nix", "type": "github" } }, @@ -56,6 +133,7 @@ "tgi-nix", "nixpkgs" ], + "poetry2nix": "poetry2nix", "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } @@ -68,11 +146,11 @@ ] }, "locked": { - "lastModified": 1723170066, - "narHash": "sha256-SFkQfOA+8AIYJsPlQtxNP+z5jRLfz91z/aOrV94pPmw=", + "lastModified": 1723429325, + "narHash": "sha256-4x/32xTCd+xCwFoI/kKSiCr5LQA2ZlyTRYXKEni5HR8=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "fecfe4d7c96fea2982c7907997b387a6b52c1093", + "rev": "65e3dc0fe079fe8df087cd38f1fe6836a0373aad", "type": "github" }, "original": { @@ -96,17 +174,46 @@ "type": "github" } }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "id": "systems", + "type": "indirect" + } + }, "tgi-nix": { "inputs": { "flake-compat": "flake-compat", - "nixpkgs": "nixpkgs" + "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1723234585, - "narHash": "sha256-HChJpNP155FPhHr9C5BtqllV8Uv/Ebg59HhMc/HhQrc=", + "lastModified": 1723450799, + "narHash": "sha256-cuT/ce7R2D5Lx6Ted4YS4y+WrAAOXQFAbzLsM1vtPo8=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "15bd4a978d6c2c8b04b7afc335d137dbe41e73df", + "rev": "29f9b45bd613eced65c4a5241a00aa9346f63d90", "type": "github" }, "original": { @@ -114,6 +221,27 @@ "repo": "tgi-nix", "type": "github" } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719749022, + "narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 761c4af8..1f842e8e 100644 --- a/flake.nix +++ b/flake.nix @@ -3,6 +3,7 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; + poetry2nix.url = "github:nix-community/poetry2nix"; rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; @@ -15,6 +16,7 @@ flake-utils, rust-overlay, tgi-nix, + poetry2nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -30,6 +32,11 @@ tgi-nix.overlay ]; }; + + inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; + text-generation-server = mkPoetryEditablePackage { + editablePackageSources = ./server; + }; in { devShells.default = @@ -50,14 +57,17 @@ venvShellHook pip + click einops fbgemm-gpu + flashinfer flash-attn flash-attn-layer-norm flash-attn-rotary grpc-interceptor grpcio-reflection grpcio-status + grpcio-tools hf-transfer loguru marlin-kernels From 1f8c0f83e357c98031bb8c643ba55f69da2af7e4 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 14:38:59 -0400 Subject: [PATCH 221/340] Pr 2395 ci run (#2406) * fix(router): Fix appending to message content * feat: add message and chat template test --------- Co-authored-by: Simone Rossi --- router/src/infer/chat_template.rs | 40 ++++++++++++++++++++++++++++++- router/src/lib.rs | 35 +++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index ef4beee2..a8537818 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -94,7 +94,9 @@ impl ChatTemplate { mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; - use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken}; + use crate::{ + ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken, + }; use minijinja::Environment; #[test] @@ -823,4 +825,40 @@ mod tests { } } } + + #[test] + fn test_chat_template_with_default_tool_template() { + let ct = ChatTemplate::new( + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText("Great! How can I help you today?".to_string()), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Just testing".to_string()), + }, + ]; + let tools = serde_json::json!("[]"); + let tool_prompt = "This default prompt will be used".to_string(); + let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt); + let result = ct.apply(None, msgs, Some(grammer_with_prompt)); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string(); + assert_eq!(result.unwrap(), expected); + } } diff --git a/router/src/lib.rs b/router/src/lib.rs index d7eb4475..1b2ff153 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1016,8 +1016,10 @@ impl MessageContent { pub fn push(&mut self, chunk: MessageChunk) { match self { MessageContent::SingleText(text) => { - *self = - MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); + *self = MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: text.clone() }, + chunk, + ]); } MessageContent::MultipleChunks(chunks) => { chunks.push(chunk); @@ -1348,6 +1350,35 @@ mod tests { ); } + #[test] + fn test_message_content_append() { + let mut content = MessageContent::SingleText("Initial text".to_string()); + let chunk = MessageChunk::Text { + text: "Additional text".to_string(), + }; + + content.push(chunk); + + match content { + MessageContent::MultipleChunks(chunks) => { + assert_eq!(chunks.len(), 2); + assert_eq!( + chunks[0], + MessageChunk::Text { + text: "Initial text".to_string() + } + ); + assert_eq!( + chunks[1], + MessageChunk::Text { + text: "Additional text".to_string() + } + ); + } + _ => panic!("Expected MultipleChunks, but got a different variant"), + } + } + #[test] fn test_chat_request() { let json = json!({ From 10b2be6536fe6411068e4a752d559a14908b8ff5 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 17:59:37 -0400 Subject: [PATCH 222/340] fix: include create_exllama_buffers and set_device for exllama (#2407) --- server/text_generation_server/layers/gptq/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 9c9b69d1..505caa59 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -422,12 +422,16 @@ elif CAN_EXLLAMA: if V2: from text_generation_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, # noqa: F401 + create_exllama_buffers, # noqa: F401 + set_device, # noqa: F401 ) HAS_EXLLAMA = "2" else: from text_generation_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 + create_exllama_buffers, # noqa: F401 + set_device, # noqa: F401 ) HAS_EXLLAMA = "1" From eb561bb7150e0162c81a522c3d67bef3b8d473b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 13 Aug 2024 10:44:15 +0200 Subject: [PATCH 223/340] nix: incremental build of the launcher (#2410) --- _launcher.nix | 18 -- flake.lock | 844 +++++++++++++++++++++++++++++++++++++++++++++++++- flake.nix | 14 +- 3 files changed, 846 insertions(+), 30 deletions(-) delete mode 100644 _launcher.nix diff --git a/_launcher.nix b/_launcher.nix deleted file mode 100644 index 1acae7e1..00000000 --- a/_launcher.nix +++ /dev/null @@ -1,18 +0,0 @@ -{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: - -buildRustPackage { - name = "text-generation-lancher"; - - src = ./.; - - sourceDir = ./launcher; - - cargoLock = { - lockFile = ./Cargo.lock; - }; - - nativeBuildInputs = [ pkg-config ]; - - buildInputs = [ openssl.dev protobuf ]; - -} diff --git a/flake.lock b/flake.lock index 47f14626..bc114032 100644 --- a/flake.lock +++ b/flake.lock @@ -1,6 +1,309 @@ { "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "crate2nix" + ], + "flake-compat": [ + "crate2nix" + ], + "nixpkgs": "nixpkgs", + "pre-commit-hooks": [ + "crate2nix" + ] + }, + "locked": { + "lastModified": 1709700175, + "narHash": "sha256-A0/6ZjLmT9qdYzKHmevnEIC7G+GiZ4UCr8v0poRPzds=", + "owner": "cachix", + "repo": "cachix", + "rev": "be97b37989f11b724197b5f4c7ffd78f12c8c4bf", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "cachix_2": { + "inputs": { + "devenv": [ + "crate2nix", + "crate2nix_stable" + ], + "flake-compat": [ + "crate2nix", + "crate2nix_stable" + ], + "nixpkgs": "nixpkgs_2", + "pre-commit-hooks": [ + "crate2nix", + "crate2nix_stable" + ] + }, + "locked": { + "lastModified": 1716549461, + "narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=", + "owner": "cachix", + "repo": "cachix", + "rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "cachix_3": { + "inputs": { + "devenv": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable" + ], + "flake-compat": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable" + ], + "nixpkgs": "nixpkgs_3", + "pre-commit-hooks": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable" + ] + }, + "locked": { + "lastModified": 1716549461, + "narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=", + "owner": "cachix", + "repo": "cachix", + "rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "crate2nix": { + "inputs": { + "cachix": "cachix", + "crate2nix_stable": "crate2nix_stable", + "devshell": "devshell_3", + "flake-compat": "flake-compat_3", + "flake-parts": "flake-parts_3", + "nix-test-runner": "nix-test-runner_3", + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ], + "pre-commit-hooks": "pre-commit-hooks_3" + }, + "locked": { + "lastModified": 1723311214, + "narHash": "sha256-xdGZQBEa1AC2us/sY3igS/CucWY6jErXsAvCFRhB2LI=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "236f6addfd452a48be805819e3216af79e988fd5", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "crate2nix", + "type": "github" + } + }, + "crate2nix_stable": { + "inputs": { + "cachix": "cachix_2", + "crate2nix_stable": "crate2nix_stable_2", + "devshell": "devshell_2", + "flake-compat": "flake-compat_2", + "flake-parts": "flake-parts_2", + "nix-test-runner": "nix-test-runner_2", + "nixpkgs": "nixpkgs_5", + "pre-commit-hooks": "pre-commit-hooks_2" + }, + "locked": { + "lastModified": 1719760004, + "narHash": "sha256-esWhRnt7FhiYq0CcIxw9pvH+ybOQmWBfHYMtleaMhBE=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "1dee214bb20855fa3e1e7bb98d28922ddaff8c57", + "type": "github" + }, + "original": { + "owner": "nix-community", + "ref": "0.14.1", + "repo": "crate2nix", + "type": "github" + } + }, + "crate2nix_stable_2": { + "inputs": { + "cachix": "cachix_3", + "crate2nix_stable": "crate2nix_stable_3", + "devshell": "devshell", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "nix-test-runner": "nix-test-runner", + "nixpkgs": "nixpkgs_4", + "pre-commit-hooks": "pre-commit-hooks" + }, + "locked": { + "lastModified": 1712821484, + "narHash": "sha256-rGT3CW64cJS9nlnWPFWSc1iEa3dNZecVVuPVGzcsHe8=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "42883afcad3823fa5811e967fb7bff54bc3c9d6d", + "type": "github" + }, + "original": { + "owner": "nix-community", + "ref": "0.14.0", + "repo": "crate2nix", + "type": "github" + } + }, + "crate2nix_stable_3": { + "inputs": { + "flake-utils": "flake-utils" + }, + "locked": { + "lastModified": 1702842982, + "narHash": "sha256-A9AowkHIjsy1a4LuiPiVP88FMxyCWK41flZEZOUuwQM=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "75ac2973affa6b9b4f661a7b592cba6e4f51d426", + "type": "github" + }, + "original": { + "owner": "nix-community", + "ref": "0.12.0", + "repo": "crate2nix", + "type": "github" + } + }, + "devshell": { + "inputs": { + "flake-utils": "flake-utils_2", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1717408969, + "narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=", + "owner": "numtide", + "repo": "devshell", + "rev": "1ebbe68d57457c8cae98145410b164b5477761f4", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "devshell", + "type": "github" + } + }, + "devshell_2": { + "inputs": { + "flake-utils": "flake-utils_3", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1717408969, + "narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=", + "owner": "numtide", + "repo": "devshell", + "rev": "1ebbe68d57457c8cae98145410b164b5477761f4", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "devshell", + "type": "github" + } + }, + "devshell_3": { + "inputs": { + "flake-utils": "flake-utils_4", + "nixpkgs": [ + "crate2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1711099426, + "narHash": "sha256-HzpgM/wc3aqpnHJJ2oDqPBkNsqWbW0WfWUO8lKu8nGk=", + "owner": "numtide", + "repo": "devshell", + "rev": "2d45b54ca4a183f2fdcf4b19c895b64fbf620ee8", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "devshell", + "type": "github" + } + }, "flake-compat": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_3": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_4": { "locked": { "lastModified": 1696426674, "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", @@ -15,10 +318,148 @@ "type": "github" } }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719745305, + "narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_2": { + "inputs": { + "nixpkgs-lib": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719745305, + "narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_3": { + "inputs": { + "nixpkgs-lib": [ + "crate2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1712014858, + "narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "9126214d0a59633752a136528f5f3b9aa8565b7d", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, "flake-utils": { "inputs": { "systems": "systems" }, + "locked": { + "lastModified": 1694529238, + "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_3": { + "inputs": { + "systems": "systems_3" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_4": { + "inputs": { + "systems": "systems_4" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_5": { + "inputs": { + "systems": "systems_5" + }, "locked": { "lastModified": 1710146030, "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", @@ -33,9 +474,9 @@ "type": "github" } }, - "flake-utils_2": { + "flake-utils_6": { "inputs": { - "systems": "systems_2" + "systems": "systems_6" }, "locked": { "lastModified": 1710146030, @@ -51,6 +492,93 @@ "type": "github" } }, + "flake-utils_7": { + "inputs": { + "systems": "systems_7" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "gitignore_2": { + "inputs": { + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "gitignore_3": { + "inputs": { + "nixpkgs": [ + "crate2nix", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -72,7 +600,129 @@ "type": "github" } }, + "nix-test-runner": { + "flake": false, + "locked": { + "lastModified": 1588761593, + "narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=", + "owner": "stoeffel", + "repo": "nix-test-runner", + "rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2", + "type": "github" + }, + "original": { + "owner": "stoeffel", + "repo": "nix-test-runner", + "type": "github" + } + }, + "nix-test-runner_2": { + "flake": false, + "locked": { + "lastModified": 1588761593, + "narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=", + "owner": "stoeffel", + "repo": "nix-test-runner", + "rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2", + "type": "github" + }, + "original": { + "owner": "stoeffel", + "repo": "nix-test-runner", + "type": "github" + } + }, + "nix-test-runner_3": { + "flake": false, + "locked": { + "lastModified": 1588761593, + "narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=", + "owner": "stoeffel", + "repo": "nix-test-runner", + "rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2", + "type": "github" + }, + "original": { + "owner": "stoeffel", + "repo": "nix-test-runner", + "type": "github" + } + }, "nixpkgs": { + "locked": { + "lastModified": 1700612854, + "narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1715534503, + "narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2057814051972fa1453ddfb0d98badbea9b83c06", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_3": { + "locked": { + "lastModified": 1715534503, + "narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2057814051972fa1453ddfb0d98badbea9b83c06", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_4": { + "locked": { + "lastModified": 1719506693, + "narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=", + "path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source", + "rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a", + "type": "path" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "nixpkgs_5": { + "locked": { + "lastModified": 1719506693, + "narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=", + "path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source", + "rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a", + "type": "path" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "nixpkgs_6": { "locked": { "lastModified": 1719763542, "narHash": "sha256-mXkOj9sJ0f69Nkc2dGGOWtof9d1YNY8Le/Hia3RN+8Q=", @@ -88,7 +738,7 @@ "type": "github" } }, - "nixpkgs_2": { + "nixpkgs_7": { "locked": { "lastModified": 1723418128, "narHash": "sha256-k1pEqsnB6ikZyasXbtV6A9akPZMKlsyENPDUA6PXoJo=", @@ -106,10 +756,10 @@ }, "poetry2nix": { "inputs": { - "flake-utils": "flake-utils_2", + "flake-utils": "flake-utils_7", "nix-github-actions": "nix-github-actions", - "nixpkgs": "nixpkgs", - "systems": "systems_3", + "nixpkgs": "nixpkgs_6", + "systems": "systems_8", "treefmt-nix": "treefmt-nix" }, "locked": { @@ -126,9 +776,110 @@ "type": "github" } }, + "pre-commit-hooks": { + "inputs": { + "flake-compat": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ], + "nixpkgs-stable": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719259945, + "narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "pre-commit-hooks_2": { + "inputs": { + "flake-compat": [ + "crate2nix", + "crate2nix_stable", + "flake-compat" + ], + "gitignore": "gitignore_2", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ], + "nixpkgs-stable": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719259945, + "narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "pre-commit-hooks_3": { + "inputs": { + "flake-compat": [ + "crate2nix", + "flake-compat" + ], + "flake-utils": "flake-utils_5", + "gitignore": "gitignore_3", + "nixpkgs": [ + "crate2nix", + "nixpkgs" + ], + "nixpkgs-stable": [ + "crate2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1712055707, + "narHash": "sha256-4XLvuSIDZJGS17xEwSrNuJLL7UjDYKGJSbK1WWX2AK8=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "e35aed5fda3cc79f88ed7f1795021e559582093a", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, "root": { "inputs": { - "flake-utils": "flake-utils", + "crate2nix": "crate2nix", + "flake-utils": "flake-utils_6", "nixpkgs": [ "tgi-nix", "nixpkgs" @@ -190,6 +941,81 @@ } }, "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_4": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_5": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_6": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_7": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_8": { "locked": { "lastModified": 1681028828, "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", @@ -205,8 +1031,8 @@ }, "tgi-nix": { "inputs": { - "flake-compat": "flake-compat", - "nixpkgs": "nixpkgs_2" + "flake-compat": "flake-compat_4", + "nixpkgs": "nixpkgs_7" }, "locked": { "lastModified": 1723450799, diff --git a/flake.nix b/flake.nix index 1f842e8e..40b845b7 100644 --- a/flake.nix +++ b/flake.nix @@ -1,5 +1,9 @@ { inputs = { + crate2nix = { + url = "github:nix-community/crate2nix"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; @@ -12,6 +16,7 @@ outputs = { self, + crate2nix, nixpkgs, flake-utils, rust-overlay, @@ -21,6 +26,10 @@ flake-utils.lib.eachDefaultSystem ( system: let + cargoNix = crate2nix.tools.${system}.appliedCargoNix { + name = "tgi"; + src = ./.; + }; config = { allowUnfree = true; cudaSupport = true; @@ -81,12 +90,11 @@ transformers vllm + cargoNix.workspaceMembers.text-generation-launcher.build + (callPackage ./router.nix { inherit (rustPlatform) buildRustPackage importCargoLock; }) - (callPackage ./_launcher.nix { - inherit (rustPlatform) buildRustPackage importCargoLock; - }) ]); venvDir = "./.venv"; From c5e4c1877b80e7fcb2a2efd4688f40ffb49cdf34 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 13 Aug 2024 10:49:18 +0200 Subject: [PATCH 224/340] Adding more kernels to flake. (#2411) --- flake.lock | 18 +++++++++--------- flake.nix | 2 ++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/flake.lock b/flake.lock index bc114032..ef05c258 100644 --- a/flake.lock +++ b/flake.lock @@ -763,11 +763,11 @@ "treefmt-nix": "treefmt-nix" }, "locked": { - "lastModified": 1723343306, - "narHash": "sha256-/6sRkPq7/5weX2y0V8sQ29Sz35nt8kyj+BsFtkhgbJE=", + "lastModified": 1723512448, + "narHash": "sha256-VSTtxGKre1p6zd6ACuBmfDcR+BT9+ml8Y3KrSbfGFYU=", "owner": "nix-community", "repo": "poetry2nix", - "rev": "4a1c112ff0c67f496573dc345bd0b2247818fc29", + "rev": "ed52f844c4dd04dde45550c3189529854384124e", "type": "github" }, "original": { @@ -897,11 +897,11 @@ ] }, "locked": { - "lastModified": 1723429325, - "narHash": "sha256-4x/32xTCd+xCwFoI/kKSiCr5LQA2ZlyTRYXKEni5HR8=", + "lastModified": 1723515680, + "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "65e3dc0fe079fe8df087cd38f1fe6836a0373aad", + "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3", "type": "github" }, "original": { @@ -1035,11 +1035,11 @@ "nixpkgs": "nixpkgs_7" }, "locked": { - "lastModified": 1723450799, - "narHash": "sha256-cuT/ce7R2D5Lx6Ted4YS4y+WrAAOXQFAbzLsM1vtPo8=", + "lastModified": 1723532088, + "narHash": "sha256-6h/P/BkFDw8txlikonKXp5IbluHSPhHJTQRftJLkbLQ=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "29f9b45bd613eced65c4a5241a00aa9346f63d90", + "rev": "32335a37ce0f703bab901baf7b74eb11e9972d5f", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 40b845b7..8784be37 100644 --- a/flake.nix +++ b/flake.nix @@ -66,6 +66,7 @@ venvShellHook pip + causal-conv1d click einops fbgemm-gpu @@ -79,6 +80,7 @@ grpcio-tools hf-transfer loguru + mamba-ssm marlin-kernels opentelemetry-api opentelemetry-exporter-otlp From 7a4d831d179783f7ec33d76d2f8fe78c92eff9df Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 13 Aug 2024 21:33:55 +0800 Subject: [PATCH 225/340] add numa to improve cpu inference perf (#2330) Signed-off-by: Wang, Yi A --- Dockerfile_intel | 12 +++---- .../models/flash_causal_lm.py | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index 158c5a89..12480c70 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -106,7 +106,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins g++ \ git \ wget \ - cmake + cmake \ + libnuma-dev ENV HUGGINGFACE_HUB_CACHE=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ @@ -135,7 +136,7 @@ RUN conda install -c conda-forge gperftools mkl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl -RUN pip install triton +RUN pip install triton numa WORKDIR /usr/src @@ -147,16 +148,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . -ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib -ENV KMP_BLOCKTIME=1 -ENV KMP_TPAUSE=0 -ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist -ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist -ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist # Install server COPY proto proto diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 21b66a68..42d93a12 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -74,6 +74,36 @@ def get_sliding_windows() -> int: return SLIDING_WINDOW +def init_cpu_threads_env(rank_id: int, world_size: int): + import importlib.util + + if importlib.util.find_spec("numa") is not None: + import numa + import psutil + + nodes = numa.get_max_node() + 1 + rank_per_node = math.ceil(world_size / nodes) + num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) + node_id = int(rank_id / rank_per_node) + rank_offset_per_node = rank_id % rank_per_node + if os.getenv("OMP_NUM_THREADS") is None: + num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) + else: + num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) + if len(numa.get_membind()) == nodes: + numa.set_membind([node_id]) + torch.set_num_threads(num_cpus_per_rank) + if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): + cpu_start = num_cpus_per_rank * rank_offset_per_node + numa.set_affinity( + 0, + list(numa.node_to_cpus(node_id))[ + cpu_start : cpu_start + num_cpus_per_rank + ], + ) + logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") + + @dataclass class FlashCausalLMBatch(Batch): batch_id: int @@ -854,6 +884,7 @@ class FlashCausalLM(Model): device = torch.device("cpu") # Float16 doesn't exist on target. dtype = torch.bfloat16 if dtype is None else dtype + init_cpu_threads_env(rank_id=rank, world_size=world_size) else: raise NotImplementedError(f"{model_class} is only available on GPU") From ffc8fb08503e3c0782c14f5bc3f0cc3ea02bc4e2 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 13 Aug 2024 10:19:46 -0400 Subject: [PATCH 226/340] fix: adds causal to attention params (#2408) fix: adds causal to attention params to check when using flash attn v1 --- server/text_generation_server/layers/attention/cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index d039e1e7..8703eb94 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -293,6 +293,7 @@ else: max_s, softmax_scale, window_size_left=-1, + causal=None, softcap=None, ): if window_size_left != -1: From bae161ab842290a921a2c72a945ab2f017c2d295 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 14 Aug 2024 11:06:28 +0200 Subject: [PATCH 227/340] nix: partial incremental build of the router (#2416) This is less incremental than crate2nix, but does build all dependencies separately, so avoids full rebuilds. --- _server.nix | 7 ++++++- flake.lock | 22 ++++++++++++++++++++++ flake.nix | 37 +++++++++++++++++++++++++++++-------- router.nix | 18 ------------------ 4 files changed, 57 insertions(+), 27 deletions(-) delete mode 100644 router.nix diff --git a/_server.nix b/_server.nix index bc11ba6b..2cb2f887 100644 --- a/_server.nix +++ b/_server.nix @@ -1,4 +1,9 @@ -{ mkPoetryApplication, pkg-config, protobuf, openssl }: +{ + mkPoetryApplication, + pkg-config, + protobuf, + openssl, +}: mkPoetryApplication { # name = "text-generation-server"; diff --git a/flake.lock b/flake.lock index ef05c258..dab582a5 100644 --- a/flake.lock +++ b/flake.lock @@ -579,6 +579,27 @@ "type": "github" } }, + "naersk": { + "inputs": { + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1721727458, + "narHash": "sha256-r/xppY958gmZ4oTfLiHN0ZGuQ+RSTijDblVgVLFi1mw=", + "owner": "nix-community", + "repo": "naersk", + "rev": "3fb418eaf352498f6b6c30592e3beb63df42ef11", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "naersk", + "type": "github" + } + }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -880,6 +901,7 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", + "naersk": "naersk", "nixpkgs": [ "tgi-nix", "nixpkgs" diff --git a/flake.nix b/flake.nix index 8784be37..e1f44212 100644 --- a/flake.nix +++ b/flake.nix @@ -7,6 +7,10 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; + naersk = { + url = "github:nix-community/naersk"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; poetry2nix.url = "github:nix-community/poetry2nix"; rust-overlay = { url = "github:oxalica/rust-overlay"; @@ -17,6 +21,7 @@ { self, crate2nix, + naersk, nixpkgs, flake-utils, rust-overlay, @@ -41,13 +46,32 @@ tgi-nix.overlay ]; }; + naersk' = pkgs.callPackage naersk { }; + router = + with pkgs; + naersk'.buildPackage { + name = "router"; + src = ./.; + cargoBuildOptions = + x: + x + ++ [ + "-p" + "text-generation-router-v3" + ]; + nativeBuildInputs = [ pkg-config ]; + buildInputs = [ + openssl.dev + protobuf + ]; + doCheck = false; + }; inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; - text-generation-server = mkPoetryEditablePackage { - editablePackageSources = ./server; - }; + text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; in { + defaultPackage = router; devShells.default = with pkgs; mkShell { @@ -92,11 +116,8 @@ transformers vllm - cargoNix.workspaceMembers.text-generation-launcher.build - - (callPackage ./router.nix { - inherit (rustPlatform) buildRustPackage importCargoLock; - }) + cargoNix.workspaceMembers.text-generation-launcher.build + router ]); venvDir = "./.venv"; diff --git a/router.nix b/router.nix deleted file mode 100644 index eeeac199..00000000 --- a/router.nix +++ /dev/null @@ -1,18 +0,0 @@ -{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: - -buildRustPackage { - name = "text-generation-router"; - - src = ./.; - - sourceDir = ./backends/v3; - - cargoLock = { - lockFile = ./Cargo.lock; - }; - - nativeBuildInputs = [ pkg-config ]; - - buildInputs = [ openssl.dev protobuf ]; - -} From 4baa6ff59f45cda2bbdc30289cfc4357a0b8b426 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 14 Aug 2024 11:58:08 +0200 Subject: [PATCH 228/340] Upgrading exl2. (#2415) * Upgrading exl2. * Fixing the other pathways. * Fix idefics. --- .gitignore | 2 +- flake.nix | 1 + server/Makefile | 1 + server/Makefile-exllamav2 | 12 ++++++++++++ server/text_generation_server/models/causal_lm.py | 1 + .../text_generation_server/models/flash_causal_lm.py | 1 + server/text_generation_server/models/idefics.py | 1 + .../models/idefics_causal_lm.py | 1 + server/text_generation_server/models/seq2seq_lm.py | 1 + server/text_generation_server/server.py | 6 +++--- 10 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 server/Makefile-exllamav2 diff --git a/.gitignore b/.gitignore index bd9d9125..f79d8faa 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ backends/client/src/v3/pb # ROCm auto-generated files *.hip -server/exllamav2_kernels/exllamav2_kernels/hip/ +server/exllamav2 server/exllama_kernels/exllama_kernels/hip/ server/exllama_kernels/exllama_kernels/hip_func/ *_hip.cuh diff --git a/flake.nix b/flake.nix index e1f44212..229184d2 100644 --- a/flake.nix +++ b/flake.nix @@ -93,6 +93,7 @@ causal-conv1d click einops + exllamav2 fbgemm-gpu flashinfer flash-attn diff --git a/server/Makefile b/server/Makefile index 209fc44e..51ea8b32 100644 --- a/server/Makefile +++ b/server/Makefile @@ -6,6 +6,7 @@ include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica include Makefile-fbgemm +include Makefile-exllamav2 unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-exllamav2 b/server/Makefile-exllamav2 new file mode 100644 index 00000000..0d4cc385 --- /dev/null +++ b/server/Makefile-exllamav2 @@ -0,0 +1,12 @@ +exllamav2_commit := v0.1.8 + +build-exllamav2: + git clone https://github.com/turboderp/exllamav2.git exllamav2 && \ + cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build + +install-exllamav2: build-exllamav2 + cd exllamav2/ && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 212ab7a9..ba168b13 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -511,6 +511,7 @@ class CausalLM(Model): config_class=AutoConfig, batch_class=CausalLMBatch, ): + self.quantize = quantize self.batch_class = batch_class self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 42d93a12..5e2fd20a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -872,6 +872,7 @@ class FlashCausalLM(Model): head_size: Optional[int] = None, skip_special_tokens: bool = True, ): + self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 29929b98..9058cb96 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -33,6 +33,7 @@ class IDEFICSSharded(IdeficsCausalLM): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 8a80ed68..c5480952 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -580,6 +580,7 @@ class IdeficsCausalLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize from text_generation_server.models.custom_modeling.idefics_modeling import ( IdeficsForVisionText2Text, ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 79c001b0..3c92128a 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -553,6 +553,7 @@ class Seq2SeqLM(Model): tokenizer_class=AutoTokenizer, aliases=None, ): + self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index b92ab572..22871ec5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -50,12 +50,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): self, model: Model, cache: Cache, - quantize: Optional[str], server_urls: List[str], ): self.cache = cache self.model = model - self.quantize = quantize + # Quantize is resolved during model loading + self.quantize = model.quantize self.server_urls = server_urls # For some reason, inference_mode does not work well with GLOO which we use on CPU if model.device.type == "cuda": @@ -255,7 +255,7 @@ def serve( ], ) generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( - TextGenerationService(model, Cache(), quantize, server_urls), server + TextGenerationService(model, Cache(), server_urls), server ) SERVICE_NAMES = ( generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, From c3401e0b994613690d9484177573761ade842749 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Wed, 14 Aug 2024 12:02:05 +0200 Subject: [PATCH 229/340] More fixes trtllm (#2342) * (backend) use parking_lot crate for RwLock fairness * (docker) let's put rust in the TRTLLM folder when building * (docker) build ompi with SLURM support * (launcher) default new server::run parameters to false for now * (chore) fmt ... why? --- backends/trtllm/Cargo.toml | 5 +++-- backends/trtllm/Dockerfile | 9 +++++---- backends/trtllm/src/backend.rs | 3 ++- backends/trtllm/src/main.rs | 8 ++++---- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml index 7079d3d1..43a114ba 100644 --- a/backends/trtllm/Cargo.toml +++ b/backends/trtllm/Cargo.toml @@ -8,17 +8,18 @@ homepage.workspace = true [dependencies] async-trait = "0.1" async-stream = "0.3" +clap = { version = "4.5", features = ["derive"] } cxx = "1.0" +log = { version = "0.4", features = [] } text-generation-router = { path = "../../router" } tokenizers = { version = "0.19", features = ["hf-hub"] } tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.15" -clap = { version = "4.5", features = ["derive"] } thiserror = "1.0.62" tracing = "0.1" tracing-opentelemetry = "0.24" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -log = { version = "0.4", features = [] } +parking_lot = "0.12" [build-dependencies] cmake = "0.1" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile index 60ad03f7..5fd2f89f 100644 --- a/backends/trtllm/Dockerfile +++ b/backends/trtllm/Dockerfile @@ -3,7 +3,7 @@ ARG OMPI_VERSION="4.1.6" # Build dependencies resolver stage FROM lukemathwalker/cargo-chef:latest AS chef -WORKDIR /usr/src/text-generation-inference +WORKDIR /usr/src/text-generation-inference/backends/trtllm FROM chef AS planner COPY . . @@ -42,7 +42,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE mkdir /usr/src/mpi && \ tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ cd /usr/src/mpi && \ - ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \ make -j all && \ make install && \ rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" @@ -66,7 +66,7 @@ ENV PATH="/root/.cargo/bin:$PATH" RUN cargo install cargo-chef # Cache dependencies -COPY --from=planner /usr/src/text-generation-inference/recipe.json . +COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json . RUN cargo chef cook --release --recipe-path recipe.json # Build actual TGI @@ -79,7 +79,8 @@ COPY . . COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ - CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm + cd backends/trtllm && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime WORKDIR /usr/local/tgi/bin diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs index b26d06a6..b23aa6c0 100644 --- a/backends/trtllm/src/backend.rs +++ b/backends/trtllm/src/backend.rs @@ -12,12 +12,13 @@ use cxx::UniquePtr; use log::{error, warn}; use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; -use tokio::sync::RwLock; use tokio::time::{sleep, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::{Stream, StreamExt}; use tracing::{instrument, span, Level}; +// use tokio::sync::RwLock; +use parking_lot::RwLock; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidationError::UnsupportedModality; use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 6d6ee146..e0ba46c7 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -1,12 +1,10 @@ +use clap::Parser; use std::collections::HashMap; use std::path::PathBuf; - -use clap::Parser; -use tokenizers::{FromPretrainedParameters, Tokenizer}; - use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackend; use text_generation_router::server; +use tokenizers::{FromPretrainedParameters, Tokenizer}; /// App Configuration #[derive(Parser, Debug)] @@ -160,6 +158,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { messages_api_enabled, true, max_client_batch_size, + false, + false, ) .await?; Ok(()) From e5c39a55456e5c9a0b1579a308047fa7c084bc1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 15 Aug 2024 10:21:51 +0200 Subject: [PATCH 230/340] nix: build router incrementally (#2422) --- flake.lock | 22 ---------------------- flake.nix | 53 +++++++++++++++++++++++++---------------------------- 2 files changed, 25 insertions(+), 50 deletions(-) diff --git a/flake.lock b/flake.lock index dab582a5..ef05c258 100644 --- a/flake.lock +++ b/flake.lock @@ -579,27 +579,6 @@ "type": "github" } }, - "naersk": { - "inputs": { - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1721727458, - "narHash": "sha256-r/xppY958gmZ4oTfLiHN0ZGuQ+RSTijDblVgVLFi1mw=", - "owner": "nix-community", - "repo": "naersk", - "rev": "3fb418eaf352498f6b6c30592e3beb63df42ef11", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "naersk", - "type": "github" - } - }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -901,7 +880,6 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", - "naersk": "naersk", "nixpkgs": [ "tgi-nix", "nixpkgs" diff --git a/flake.nix b/flake.nix index 229184d2..168c8649 100644 --- a/flake.nix +++ b/flake.nix @@ -7,10 +7,6 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; - naersk = { - url = "github:nix-community/naersk"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; - }; poetry2nix.url = "github:nix-community/poetry2nix"; rust-overlay = { url = "github:oxalica/rust-overlay"; @@ -21,7 +17,6 @@ { self, crate2nix, - naersk, nixpkgs, flake-utils, rust-overlay, @@ -34,6 +29,7 @@ cargoNix = crate2nix.tools.${system}.appliedCargoNix { name = "tgi"; src = ./.; + additionalCargoNixArgs = [ "--all-features" ]; }; config = { allowUnfree = true; @@ -46,32 +42,10 @@ tgi-nix.overlay ]; }; - naersk' = pkgs.callPackage naersk { }; - router = - with pkgs; - naersk'.buildPackage { - name = "router"; - src = ./.; - cargoBuildOptions = - x: - x - ++ [ - "-p" - "text-generation-router-v3" - ]; - nativeBuildInputs = [ pkg-config ]; - buildInputs = [ - openssl.dev - protobuf - ]; - doCheck = false; - }; - inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; in { - defaultPackage = router; devShells.default = with pkgs; mkShell { @@ -118,7 +92,30 @@ vllm cargoNix.workspaceMembers.text-generation-launcher.build - router + + (cargoNix.workspaceMembers.text-generation-router-v3.build.override { + crateOverrides = defaultCrateOverrides // { + aws-lc-rs = attrs: { + # aws-lc-rs does its own custom parsing of Cargo environment + # variables like DEP_.*_INCLUDE. However buildRustCrate does + # not use the version number, so the parsing fails. + postPatch = '' + substituteInPlace build.rs \ + --replace-fail \ + "assert!(!selected.is_empty()" \ + "// assert!(!selected.is_empty()" + ''; + }; + rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; + text-generation-router-v3 = attrs: { + # We need to do the src/source root dance so that the build + # has access to the protobuf file. + src = ./.; + postPatch = "cd backends/v3"; + buildInputs = [ protobuf ]; + }; + }; + }) ]); venvDir = "./.venv"; From df6ea89da96b515e7ce965560bead40d85709e6c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Aug 2024 11:12:51 +0200 Subject: [PATCH 231/340] Fixing exl2 and other quanize tests again. (#2419) * Fixing exl2 and other quanize tests again. * Mark exl2 as non release (so CI tests them, needs to be removed latet). * Fixing exl2 (by disabling cuda graphs) * Fix quantization defaults without cuda graphs on exl2 (linked to new issues with it). * Removing serde override. * Go back to released exl2 and remove log. * Adding warnings for deprecated bitsandbytes + upgrade info to warn. --- .../models/test_flash_llama_exl2.py | 3 - launcher/src/main.rs | 156 ++++++++++-------- server/poetry.lock | 69 +++++++- server/pyproject.toml | 1 + server/requirements_cuda.txt | 4 + server/requirements_intel.txt | 4 + server/requirements_rocm.txt | 4 + .../models/causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + .../models/seq2seq_lm.py | 1 + 10 files changed, 173 insertions(+), 71 deletions(-) diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py index 7169c999..18319f60 100644 --- a/integration-tests/models/test_flash_llama_exl2.py +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle): return flash_llama_exl2_handle.client -@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): @@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot -@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_all_params( @@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params( assert response == ignore_logprob_response_snapshot -@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_load( diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 4c8f288a..e4505b35 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -32,11 +32,18 @@ struct RawConfig { n_positions: Option, model_type: Option, max_seq_len: Option, + quantization_config: Option, +} + +#[derive(Deserialize)] +struct QuantizationConfig { + quant_method: Option, } #[derive(Deserialize)] struct Config { max_position_embeddings: Option, + quantize: Option, } impl From for Config { @@ -45,13 +52,16 @@ impl From for Config { .max_position_embeddings .or(other.max_seq_len) .or(other.n_positions); + let quantize = other.quantization_config.and_then(|q| q.quant_method); Config { max_position_embeddings, + quantize, } } } -#[derive(Clone, Copy, Debug, ValueEnum)] +#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)] +#[serde(rename_all = "kebab-case")] enum Quantization { /// 4 bit quantization. Requires a specific AWQ quantized model: /// . @@ -74,17 +84,17 @@ enum Quantization { Marlin, /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, /// but it is known that the model will be much slower to run than the native f16. - #[deprecated( - since = "1.1.0", - note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" - )] + // #[deprecated( + // since = "1.1.0", + // note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" + // )] Bitsandbytes, /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, /// but it is known that the model will be much slower to run than the native f16. - BitsandbytesNF4, + BitsandbytesNf4, /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better /// perplexity performance for you model - BitsandbytesFP4, + BitsandbytesFp4, /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above /// This dtype has native ops should be the fastest if available. /// This is currently not the fastest because of local unpacking + padding to satisfy matrix @@ -101,10 +111,10 @@ impl std::fmt::Display for Quantization { Quantization::Bitsandbytes => { write!(f, "bitsandbytes") } - Quantization::BitsandbytesNF4 => { + Quantization::BitsandbytesNf4 => { write!(f, "bitsandbytes-nf4") } - Quantization::BitsandbytesFP4 => { + Quantization::BitsandbytesFp4 => { write!(f, "bitsandbytes-fp4") } Quantization::Exl2 => { @@ -1089,6 +1099,7 @@ fn spawn_shards( cuda_graphs: Vec, max_total_tokens: usize, max_input_tokens: usize, + quantize: Option, max_log_level: LevelFilter, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, @@ -1110,7 +1121,6 @@ fn spawn_shards( let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); let otlp_service_name = args.otlp_service_name.clone(); - let quantize = args.quantize; let speculate = args.speculate; let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; @@ -1443,65 +1453,68 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:#?}", args); - let get_max_position_embeddings = || -> Result> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id + let get_max_positions_quantize = + || -> Result<(usize, Option), Box> { + let model_id = args.model_id.clone(); + let mut path = std::path::Path::new(&args.model_id).to_path_buf(); + let filename = if !path.exists() { + // Assume it's a hub id - let api = if let Ok(token) = std::env::var("HF_TOKEN") { - // env variable has precedence over on file token. - ApiBuilder::new().with_token(Some(token)).build()? + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = args.revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? } else { - Api::new()? + path.push("config.json"); + path }; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - let content = std::fs::read_to_string(filename)?; - let config: RawConfig = serde_json::from_str(&content)?; + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; - if config.model_type == Some("gemma2".to_string()) { - tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("ATTENTION", "flashdecoding"); - } - let config: Config = config.into(); - - // Quantization usually means you're even more RAM constrained. - let max_default = 4096; - - if let Some(max_position_embeddings) = config.max_position_embeddings { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - Ok(max_default) - } else { - Ok(max_position_embeddings) + if config.model_type == Some("gemma2".to_string()) { + tracing::info!("Forcing flash decoding because of softcap usage"); + std::env::set_var("ATTENTION", "flashdecoding"); } - } else { - Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))) - } - }; - let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); + let config: Config = config.into(); + let quantize = config.quantize; + + // Quantization usually means you're even more RAM constrained. + let max_default = 4096; + + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } + Ok((max_default, quantize)) + } else { + Ok((max_position_embeddings, quantize)) + } + } else { + Err(Box::new(LauncherError::ArgumentValidation( + "no max defined".to_string(), + ))) + } + }; + let (max_position_embeddings, quantize): (usize, Option) = + get_max_positions_quantize().unwrap_or((4096, None)); let max_input_tokens = { match (args.max_input_tokens, args.max_input_length) { @@ -1558,18 +1571,26 @@ fn main() -> Result<(), LauncherError> { ))); } - let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { + if matches!(args.quantize, Some(Quantization::Bitsandbytes)) { + tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases."); + } + let quantize = args.quantize.or(quantize); + let cuda_graphs = match (&args.cuda_graphs, &quantize) { (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), #[allow(deprecated)] ( None, Some( Quantization::Bitsandbytes - | Quantization::BitsandbytesNF4 - | Quantization::BitsandbytesFP4, + | Quantization::BitsandbytesNf4 + | Quantization::BitsandbytesFp4, ), ) => { - tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + vec![] + } + (None, Some(Quantization::Exl2)) => { + tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them"); vec![] } _ => { @@ -1686,6 +1707,7 @@ fn main() -> Result<(), LauncherError> { cuda_graphs, max_total_tokens, max_input_tokens, + quantize, max_log_level, shutdown.clone(), &shutdown_receiver, diff --git a/server/poetry.lock b/server/poetry.lock index 5072aa0b..fc1a54a3 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1207,6 +1231,17 @@ torch = "*" type = "url" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -2277,6 +2312,20 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pygments" +version = "2.18.0" +description = "Pygments is a syntax highlighting package written in Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pytest" version = "7.4.4" @@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.7.1" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, + {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "rpds-py" version = "0.19.0" @@ -3584,4 +3651,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1" +content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd" diff --git a/server/pyproject.toml b/server/pyproject.toml index da0d1526..ca54bee9 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -46,6 +46,7 @@ marlin-kernels = [ { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] +rich = "^13.7.1" [tool.poetry.extras] torch = ["torch"] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 828b6fca..eb521bd6 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index 828b6fca..eb521bd6 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 828b6fca..eb521bd6 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba168b13..28534d0f 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -652,6 +652,7 @@ class CausalLM(Model): dtype=dtype, device=device, ) + self.quantize = quantize return self @property diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 5d6ce3c7..f6dcde68 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -412,6 +412,7 @@ class Mamba(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize self.process_group, _rank, world_size = initialize_torch_distributed() if world_size > 1: raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3c92128a..04d4c28b 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -676,6 +676,7 @@ class Seq2SeqLM(Model): dtype=dtype, device=device, ) + self.quantize = quantize return self @property From f0181ed2d78af10d5afd4236bce9f486eb4c863b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Aug 2024 13:28:42 +0200 Subject: [PATCH 232/340] Upgrading the tests to match the current workings. (#2423) --- .../test_bloom_560m/test_bloom_560m.json | 30 +- .../test_flash_deepseek_v2_all_params.json | 8 +- .../test_flash_gemma/test_flash_gemma.json | 32 +- .../test_flash_starcoder_gptq.json | 146 +------- ...t_flash_starcoder_gptq_default_params.json | 144 +------- .../test_flash_starcoder_gptq_load.json | 344 ++++-------------- .../test_mamba/test_mamba_all_params.json | 24 +- .../test_mt0_base_all_params.json | 6 +- .../models/test_flash_starcoder_gptq.py | 4 +- integration-tests/models/test_mamba.py | 2 +- 10 files changed, 166 insertions(+), 574 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json index b274992e..5d0eeef6 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json @@ -11,52 +11,52 @@ }, { "id": 49833, - "logprob": -10.5703125, + "logprob": -10.546875, "text": " dég" }, { "id": 21543, - "logprob": -0.14746094, + "logprob": -0.18457031, "text": "uster" }, { "id": 447, - "logprob": -1.9277344, + "logprob": -1.9287109, "text": " un" }, { "id": 46341, - "logprob": -15.421875, + "logprob": -15.4296875, "text": " ort" }, { "id": 35567, - "logprob": -7.5820312, + "logprob": -7.578125, "text": "olan" }, { "id": 15, - "logprob": -1.4013672, + "logprob": -1.4003906, "text": "," }, { "id": 1669, - "logprob": -1.5664062, + "logprob": -1.5439453, "text": " il" }, { "id": 11580, - "logprob": -0.94189453, + "logprob": -0.93896484, "text": " faut" }, { "id": 3913, - "logprob": -3.6816406, + "logprob": -3.7207031, "text": " tout" }, { "id": 39261, - "logprob": -1.7753906, + "logprob": -1.5742188, "text": " d'abord" } ], @@ -64,13 +64,13 @@ "tokens": [ { "id": 578, - "logprob": -1.6318359, + "logprob": -1.6474609, "special": false, "text": " le" }, { "id": 5608, - "logprob": -2.4882812, + "logprob": -2.4707031, "special": false, "text": " faire" }, @@ -88,19 +88,19 @@ }, { "id": 693, - "logprob": -2.4472656, + "logprob": -2.4628906, "special": false, "text": " à" }, { "id": 366, - "logprob": -1.1972656, + "logprob": -1.1953125, "special": false, "text": " la" }, { "id": 48844, - "logprob": -1.7890625, + "logprob": -1.7978516, "special": false, "text": " cass" }, diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json index 6b45cf6b..3ac8d050 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -11,7 +11,7 @@ }, { "id": 3533, - "logprob": -9.625, + "logprob": -9.5625, "text": "Test" }, { @@ -24,13 +24,13 @@ "tokens": [ { "id": 2143, - "logprob": -1.828125, + "logprob": -1.8203125, "special": false, "text": " sent" }, { "id": 10081, - "logprob": -0.41210938, + "logprob": -0.55078125, "special": false, "text": " successfully" }, @@ -42,7 +42,7 @@ }, { "id": 100001, - "logprob": -0.16015625, + "logprob": -0.12695312, "special": true, "text": "<|end▁of▁sentence|>" } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json index 8829f9fe..96f2ce17 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -24,13 +24,13 @@ "tokens": [ { "id": 1736, - "logprob": -2.03125, + "logprob": -2.046875, "special": false, "text": " form" }, { "id": 109, - "logprob": -1.8671875, + "logprob": -1.8828125, "special": false, "text": "\n\n" }, @@ -42,48 +42,48 @@ }, { "id": 2121, - "logprob": -1.8125, + "logprob": -1.78125, "special": false, "text": " test" }, { "id": 3853, - "logprob": -0.24121094, + "logprob": -0.23632812, "special": false, "text": " request" }, { "id": 1736, - "logprob": -0.100097656, + "logprob": -0.09326172, "special": false, "text": " form" }, { "id": 603, - "logprob": -0.9453125, + "logprob": -0.8828125, "special": false, "text": " is" }, { - "id": 476, - "logprob": -1.703125, + "id": 1671, + "logprob": -1.6171875, "special": false, - "text": " a" + "text": " used" }, { - "id": 4551, - "logprob": -2.453125, + "id": 577, + "logprob": -0.390625, "special": false, - "text": " document" + "text": " to" }, { - "id": 674, - "logprob": -0.796875, + "id": 3853, + "logprob": -1.2265625, "special": false, - "text": " that" + "text": " request" } ], "top_tokens": null }, - "generated_text": " form\n\nThe test request form is a document that" + "generated_text": " form\n\nThe test request form is used to request" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json index 5e537bb7..26224118 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5859375, + "logprob": -8.8515625, "text": "ometric" }, { "id": 81, - "logprob": -0.2668457, + "logprob": -0.21875, "text": "_" }, { "id": 6009, - "logprob": -1.6416016, + "logprob": -1.2773438, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25195312, "text": "(" }, { "id": 62, - "logprob": -5.2304688, + "logprob": -4.8203125, "text": "L" }, { "id": 44, - "logprob": -3.0976562, + "logprob": -3.7734375, "text": ":" }, { "id": 1682, - "logprob": -1.1044922, + "logprob": -0.8310547, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22766113, "text": "[" }, { "id": 1808, - "logprob": -0.32299805, + "logprob": -0.46240234, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0234375, "text": "]):" } ], @@ -69,126 +69,18 @@ "tokens": [ { "id": 284, - "logprob": -0.1282959, + "logprob": -0.04626465, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.97998047, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.7006836, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.1933594, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2697754, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.0836792, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.018737793, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.028640747, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29467773, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.31518555, - "special": false, - "text": " a" - }, - { - "id": 1149, - "logprob": -0.20605469, - "special": false, - "text": " list" - }, - { - "id": 432, - "logprob": -0.23254395, - "special": false, - "text": " of" - }, - { - "id": 7515, - "logprob": -0.4489746, - "special": false, - "text": " numbers" - }, - { - "id": 32, - "logprob": -0.6044922, - "special": false, - "text": "." - }, - { - "id": 446, - "logprob": -0.63964844, - "special": false, - "text": "\n\n " - }, - { - "id": 499, - "logprob": -1.1953125, - "special": false, - "text": " :" - }, - { - "id": 753, - "logprob": -0.03515625, - "special": false, - "text": "param" - }, - { - "id": 498, - "logprob": -0.06311035, - "special": false, - "text": " L" - }, - { - "id": 44, - "logprob": -0.003414154, - "special": false, - "text": ":" - }, - { - "id": 1682, - "logprob": -1.3310547, - "special": false, - "text": " List" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" + "generated_text": "\n " } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index bf0f5146..015912f8 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5898438, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26586914, + "logprob": -0.21984863, "text": "_" }, { "id": 6009, - "logprob": -1.6347656, + "logprob": -1.2861328, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25219727, "text": "(" }, { "id": 62, - "logprob": -5.2382812, + "logprob": -4.8007812, "text": "L" }, { "id": 44, - "logprob": -3.0996094, + "logprob": -3.7949219, "text": ":" }, { "id": 1682, - "logprob": -1.1025391, + "logprob": -0.8046875, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22424316, "text": "[" }, { "id": 1808, - "logprob": -0.32226562, + "logprob": -0.46191406, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0253906, "text": "]):" } ], @@ -74,121 +74,13 @@ "text": "\n " }, { - "id": 442, - "logprob": -1.3134766, - "special": false, - "text": " return" - }, - { - "id": 11665, - "logprob": -0.10021973, - "special": false, - "text": " reduce" - }, - { - "id": 26, - "logprob": 0.0, - "special": false, - "text": "(" - }, - { - "id": 5962, - "logprob": 0.0, - "special": false, - "text": "lambda" - }, - { - "id": 816, - "logprob": 0.0, - "special": false, - "text": " x" - }, - { - "id": 30, - "logprob": 0.0, - "special": false, - "text": "," - }, - { - "id": 533, - "logprob": 0.0, - "special": false, - "text": " y" - }, - { - "id": 44, - "logprob": 0.0, - "special": false, - "text": ":" - }, - { - "id": 816, - "logprob": 0.0, - "special": false, - "text": " x" - }, - { - "id": 319, - "logprob": -0.42871094, - "special": false, - "text": " *" - }, - { - "id": 533, - "logprob": 0.0, - "special": false, - "text": " y" - }, - { - "id": 30, - "logprob": 0.0, - "special": false, - "text": "," - }, - { - "id": 498, - "logprob": 0.0, - "special": false, - "text": " L" - }, - { - "id": 27, - "logprob": 0.0, - "special": false, - "text": ")" - }, - { - "id": 1115, - "logprob": 0.0, - "special": false, - "text": " **" - }, - { - "id": 308, - "logprob": 0.0, - "special": false, - "text": " (" - }, - { - "id": 35, - "logprob": 0.0, - "special": false, - "text": "1" - }, - { - "id": 32, - "logprob": -0.31323242, - "special": false, - "text": "." - }, - { - "id": 34, - "logprob": 0.0, - "special": false, - "text": "0" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" + "generated_text": "\n " } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json index 46a21ed8..c9b5ab20 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -2,8 +2,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -12,57 +12,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5820312, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26708984, + "logprob": -0.21826172, "text": "_" }, { "id": 6009, - "logprob": -1.6386719, + "logprob": -1.3085938, "text": "mean" }, { "id": 26, - "logprob": -0.22717285, + "logprob": -0.2548828, "text": "(" }, { "id": 62, - "logprob": -5.234375, + "logprob": -4.8007812, "text": "L" }, { "id": 44, - "logprob": -3.1015625, + "logprob": -3.7871094, "text": ":" }, { "id": 1682, - "logprob": -1.1083984, + "logprob": -0.81152344, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22644043, "text": "[" }, { "id": 1808, - "logprob": -0.32592773, + "logprob": -0.46313477, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0253906, "text": "]):" } ], @@ -70,74 +70,26 @@ "tokens": [ { "id": 284, - "logprob": -0.12817383, + "logprob": -0.046936035, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.9863281, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.7011719, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.2050781, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2668457, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.08465576, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.019012451, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.028625488, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29418945, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.3161621, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " }, { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -146,57 +98,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9375, "text": " ge" }, { "id": 21017, - "logprob": -7.59375, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26953125, + "logprob": -0.21899414, "text": "_" }, { "id": 6009, - "logprob": -1.640625, + "logprob": -1.3105469, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25561523, "text": "(" }, { "id": 62, - "logprob": -5.234375, + "logprob": -4.8085938, "text": "L" }, { "id": 44, - "logprob": -3.1132812, + "logprob": -3.7890625, "text": ":" }, { "id": 1682, - "logprob": -1.1123047, + "logprob": -0.80615234, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22375488, "text": "[" }, { "id": 1808, - "logprob": -0.32299805, + "logprob": -0.46801758, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0253906, "text": "]):" } ], @@ -204,74 +156,26 @@ "tokens": [ { "id": 284, - "logprob": -0.12854004, + "logprob": -0.046447754, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.9897461, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.69970703, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.2050781, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2668457, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.08496094, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.019012451, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.029037476, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.2939453, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.31591797, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " }, { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -280,57 +184,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5859375, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26586914, + "logprob": -0.2163086, "text": "_" }, { "id": 6009, - "logprob": -1.6347656, + "logprob": -1.2958984, "text": "mean" }, { "id": 26, - "logprob": -0.22766113, + "logprob": -0.2529297, "text": "(" }, { "id": 62, - "logprob": -5.2265625, + "logprob": -4.796875, "text": "L" }, { "id": 44, - "logprob": -3.0976562, + "logprob": -3.7910156, "text": ":" }, { "id": 1682, - "logprob": -1.1025391, + "logprob": -0.8076172, "text": " List" }, { "id": 77, - "logprob": -0.1427002, + "logprob": -0.22375488, "text": "[" }, { "id": 1808, - "logprob": -0.32592773, + "logprob": -0.46655273, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0234375, "text": "]):" } ], @@ -338,74 +242,26 @@ "tokens": [ { "id": 284, - "logprob": -0.13012695, + "logprob": -0.0463562, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.98046875, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.69921875, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.1992188, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2668457, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.083496094, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.01902771, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.029006958, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29248047, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.3161621, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " }, { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -414,57 +270,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5859375, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26904297, + "logprob": -0.21862793, "text": "_" }, { "id": 6009, - "logprob": -1.6386719, + "logprob": -1.3095703, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25512695, "text": "(" }, { "id": 62, - "logprob": -5.234375, + "logprob": -4.796875, "text": "L" }, { "id": 44, - "logprob": -3.1132812, + "logprob": -3.7890625, "text": ":" }, { "id": 1682, - "logprob": -1.1074219, + "logprob": -0.79589844, "text": " List" }, { "id": 77, - "logprob": -0.14477539, + "logprob": -0.22692871, "text": "[" }, { "id": 1808, - "logprob": -0.3256836, + "logprob": -0.46801758, "text": "float" }, { "id": 10794, - "logprob": -2.8027344, + "logprob": -3.0097656, "text": "]):" } ], @@ -472,67 +328,19 @@ "tokens": [ { "id": 284, - "logprob": -0.12915039, + "logprob": -0.04638672, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.98535156, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.69921875, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.2011719, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.26708984, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.08502197, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.019012451, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.028625488, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29589844, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.31591797, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " } ] diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json index ef88926c..93724fe4 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -11,22 +11,22 @@ }, { "id": 13, - "logprob": -2.734375, + "logprob": -2.59375, "text": "," }, { "id": 8862, - "logprob": -3.6875, + "logprob": -3.5625, "text": " yellow" }, { "id": 13, - "logprob": -0.40234375, + "logprob": -0.44726562, "text": "," }, { "id": 209, - "logprob": -8.25, + "logprob": -8.0, "text": " " } ], @@ -52,7 +52,7 @@ }, { "id": 9830, - "logprob": -2.25, + "logprob": -2.03125, "special": false, "text": " colors" }, @@ -64,13 +64,13 @@ }, { "id": 329, - "logprob": -2.171875, + "logprob": -2.734375, "special": false, "text": " A" }, { "id": 1180, - "logprob": -2.046875, + "logprob": -2.0, "special": false, "text": " number" }, @@ -81,19 +81,19 @@ "text": " of" }, { - "id": 1027, - "logprob": -1.5546875, + "id": 253, + "logprob": -0.69140625, "special": false, - "text": " different" + "text": " the" }, { "id": 3295, - "logprob": -0.97265625, + "logprob": -0.8203125, "special": false, "text": " color" } ], "top_tokens": null }, - "generated_text": "blue, red, yellow, \nand blue colors. A number of different color" + "generated_text": "blue, red, yellow, \nand blue colors. A number of the color" } diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json index 5cacf3e9..40ec7e2f 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -26,13 +26,13 @@ }, { "id": 259, - "logprob": -0.4716797, + "logprob": -0.46948242, "special": false, "text": " " }, { "id": 261, - "logprob": -0.044677734, + "logprob": -0.15307617, "special": false, "text": "," }, @@ -56,7 +56,7 @@ }, { "id": 35622, - "logprob": -1.1630859, + "logprob": -1.2998047, "special": false, "text": " cloud" }, diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py index f1007d6e..6d46e54d 100644 --- a/integration-tests/models/test_flash_starcoder_gptq.py +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -21,7 +21,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap max_new_tokens=20, decoder_input_details=True, ) - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 2 assert response == generous_response_snapshot @@ -38,7 +38,7 @@ async def test_flash_starcoder_gptq_default_params( decoder_input_details=True, seed=0, ) - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 2 assert response == generous_response_snapshot diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index bc946de8..8548970a 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -47,7 +47,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): assert response.details.generated_tokens == 10 assert ( response.generated_text - == "blue, red, yellow, \nand blue colors. A number of different color" + == "blue, red, yellow, \nand blue colors. A number of the color" ) assert response == response_snapshot From 20ed7b598e4f58f5e6b2147d5bb09bf6a8d53cdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 16 Aug 2024 10:01:01 +0200 Subject: [PATCH 233/340] nix: try to reduce the number of Rust rebuilds (#2424) Try to reduce the number of router/launcher rebuilds by filtering sources. In this way, recompiles should only be triggered by changes in Cargo or Rust files. --- flake.lock | 16 ++++++++++ flake.nix | 30 ++++--------------- nix/crate-overrides.nix | 66 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 25 deletions(-) create mode 100644 nix/crate-overrides.nix diff --git a/flake.lock b/flake.lock index ef05c258..7c772377 100644 --- a/flake.lock +++ b/flake.lock @@ -579,6 +579,21 @@ "type": "github" } }, + "nix-filter": { + "locked": { + "lastModified": 1710156097, + "narHash": "sha256-1Wvk8UP7PXdf8bCCaEoMnOT1qe5/Duqgj+rL8sRQsSM=", + "owner": "numtide", + "repo": "nix-filter", + "rev": "3342559a24e85fc164b295c3444e8a139924675b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "nix-filter", + "type": "github" + } + }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -880,6 +895,7 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", + "nix-filter": "nix-filter", "nixpkgs": [ "tgi-nix", "nixpkgs" diff --git a/flake.nix b/flake.nix index 168c8649..cf05746a 100644 --- a/flake.nix +++ b/flake.nix @@ -4,6 +4,7 @@ url = "github:nix-community/crate2nix"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; + nix-filter.url = "github:numtide/nix-filter"; tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; @@ -17,6 +18,7 @@ { self, crate2nix, + nix-filter, nixpkgs, flake-utils, rust-overlay, @@ -44,6 +46,7 @@ }; inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; + crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; in { devShells.default = @@ -91,31 +94,8 @@ transformers vllm - cargoNix.workspaceMembers.text-generation-launcher.build - - (cargoNix.workspaceMembers.text-generation-router-v3.build.override { - crateOverrides = defaultCrateOverrides // { - aws-lc-rs = attrs: { - # aws-lc-rs does its own custom parsing of Cargo environment - # variables like DEP_.*_INCLUDE. However buildRustCrate does - # not use the version number, so the parsing fails. - postPatch = '' - substituteInPlace build.rs \ - --replace-fail \ - "assert!(!selected.is_empty()" \ - "// assert!(!selected.is_empty()" - ''; - }; - rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; - text-generation-router-v3 = attrs: { - # We need to do the src/source root dance so that the build - # has access to the protobuf file. - src = ./.; - postPatch = "cd backends/v3"; - buildInputs = [ protobuf ]; - }; - }; - }) + (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }) + (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; }) ]); venvDir = "./.venv"; diff --git a/nix/crate-overrides.nix b/nix/crate-overrides.nix new file mode 100644 index 00000000..343b3b25 --- /dev/null +++ b/nix/crate-overrides.nix @@ -0,0 +1,66 @@ +{ pkgs, nix-filter }: + +let + filter = nix-filter.lib; +in +with pkgs; +defaultCrateOverrides +// { + aws-lc-rs = attrs: { + # aws-lc-rs does its own custom parsing of Cargo environment + # variables like DEP_.*_INCLUDE. However buildRustCrate does + # not use the version number, so the parsing fails. + postPatch = '' + substituteInPlace build.rs \ + --replace-fail \ + "assert!(!selected.is_empty()" \ + "// assert!(!selected.is_empty()" + ''; + }; + rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; + + grpc-metadata = attrs: { + src = + filter { + root = ../backends/grpc-metadata; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-launcer = attrs: { + src = + filter { + root = ../launcher; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-router = attrs: { + src = + filter { + root = ../router; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-router-v3 = attrs: { + # We need to do the src/source root dance so that the build + # has access to the protobuf file. + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "backends/v3") (matchExt "rs")) + (and (inDirectory "proto") (matchExt "proto")) + ]; + }; + postPatch = "cd backends/v3"; + buildInputs = [ protobuf ]; + }; +} From df0e65089133c81bacb93c52742df59c9c1110d8 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 16 Aug 2024 12:43:08 +0200 Subject: [PATCH 234/340] Improve the Consuming TGI + Streaming docs. (#2412) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve the Consuming TGI docs. * Fix erronous update to . * add info about Open AI client. * More updates. * Apply suggestions from code review Co-authored-by: Erik Kaunismäki * Suggestions from Lucain. * Update Gradio snippet. * Up. * Apply suggestions from code review Co-authored-by: Lucain * Update docs/source/basic_tutorials/consuming_tgi.md Co-authored-by: Lucain * Up. * Apply suggestions from code review Co-authored-by: Omar Sanseviero * Up. * Up. * Doc review from Nico. * Doc review from Nico. x2 * Last nit --------- Co-authored-by: Erik Kaunismäki Co-authored-by: Lucain Co-authored-by: Omar Sanseviero --- docs/openapi.json | 2 +- docs/source/basic_tutorials/consuming_tgi.md | 196 +++++++++++------- docs/source/basic_tutorials/using_guidance.md | 2 +- docs/source/conceptual/streaming.md | 94 +++++---- 4 files changed, 174 insertions(+), 120 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index d0f78061..b37eb33a 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2094,4 +2094,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} +} \ No newline at end of file diff --git a/docs/source/basic_tutorials/consuming_tgi.md b/docs/source/basic_tutorials/consuming_tgi.md index 4829ec7c..6e4ec49c 100644 --- a/docs/source/basic_tutorials/consuming_tgi.md +++ b/docs/source/basic_tutorials/consuming_tgi.md @@ -1,81 +1,125 @@ # Consuming Text Generation Inference -There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. +There are many ways to consume Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens. + +For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference). + +You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models. ## curl -After the launch, you can query the model using either the `/generate` or `/generate_stream` routes: +After a successful server launch, you can query the model using the `v1/chat/completions` route, to get responses that are compliant to the OpenAI Chat Completion spec: + +```bash +curl localhost:8080/v1/chat/completions \ + -X POST \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is deep learning?" + } + ], + "stream": true, + "max_tokens": 20 +}' \ + -H 'Content-Type: application/json' +``` + +For non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes. ```bash curl 127.0.0.1:8080/generate \ -X POST \ - -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -d '{ + "inputs":"What is Deep Learning?", + "parameters":{ + "max_new_tokens":20 + } +}' \ -H 'Content-Type: application/json' ``` +## Python -## Inference Client +### Inference Client -[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface. -You can simply install `huggingface-hub` package with pip. +[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface. + +Install `huggingface_hub` package via pip. ```bash -pip install huggingface-hub +pip install huggingface_hub ``` -Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python. +You can now use `InferenceClient` the exact same way you would use `OpenAI` client in Python ```python from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") -client.text_generation(prompt="Write a code for snake game") +client = InferenceClient( + base_url="http://localhost:8080/v1/", +) + +output = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, +) + +for chunk in output: + print(chunk.choices[0].delta.content) ``` -You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows: +You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility). + +There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) + +### OpenAI Client + +You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI. + +Install the OpenAI Python package via pip. + +```bash +pip install openai +``` ```python -for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): - print(token) +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI( + base_url="http://localhost:8080/v1/", + api_key="-" +) + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is deep learning?"} + ], + stream=True +) + +# iterate and print stream +for message in chat_completion: + print(message) ``` -Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream. +## UI -```python -output = client.text_generation(prompt="Meaning of life is", details=True) -print(output) - -# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..)) -``` - -You can see how to stream below. - -```python -output = client.text_generation(prompt="Meaning of life is", stream=True, details=True) -print(next(iter(output))) - -# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None) -``` - -You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) - - -## ChatUI - -ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces. - -To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served. - -``` -{ -// rest of the model config here -"endpoints": [{"url": "https://HOST:PORT/generate_stream"}] -} -``` - -![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png) - -## Gradio +### Gradio Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. @@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference import gradio as gr from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") +client = InferenceClient(base_url="http://127.0.0.1:8080") def inference(message, history): partial_message = "" - for token in client.text_generation(message, max_new_tokens=20, stream=True): - partial_message += token + output = client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": message}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + partial_message += chunk.choices[0].delta.content yield partial_message gr.ChatInterface( inference, chatbot=gr.Chatbot(height=300), textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), - description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", + description="This is the demo for Gradio UI consuming TGI endpoint.", title="Gradio 🤝 TGI", examples=["Are tomatoes vegetables?"], retry_btn="Retry", @@ -110,20 +163,7 @@ gr.ChatInterface( ).queue().launch() ``` -The UI looks like this 👇 - -
- - -
- -You can try the demo directly here 👇 +You can check out the UI and try the demo directly here 👇