From e74bd41e0f279ab569cf6a65ac3e2cea50e80d39 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 30 Jun 2023 19:09:59 +0200 Subject: [PATCH 01/11] feat(server): add paged attention to flash models (#516) Closes #478 --- Dockerfile | 16 +- README.md | 4 +- integration-tests/models/test_flash_neox.py | 2 + launcher/src/main.rs | 44 +- proto/generate.proto | 12 + router/client/src/client.rs | 58 ++ router/client/src/sharded_client.rs | 21 + router/src/infer.rs | 16 +- router/src/main.rs | 42 +- router/src/queue.rs | 47 +- router/src/server.rs | 2 + server/Makefile-vllm | 13 + server/text_generation_server/cache.py | 4 +- .../models/causal_lm.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 216 +++---- .../custom_modeling/flash_neox_modeling.py | 227 +++---- .../custom_modeling/flash_rw_modeling.py | 355 +++++------ .../flash_santacoder_modeling.py | 202 +++---- .../models/custom_modeling/t5_modeling.py | 4 +- .../models/flash_causal_lm.py | 567 +++++++++++------- .../models/flash_llama.py | 6 +- .../models/flash_neox.py | 6 +- .../text_generation_server/models/flash_rw.py | 6 +- .../models/flash_santacoder.py | 13 +- server/text_generation_server/models/model.py | 6 + .../models/seq2seq_lm.py | 2 +- server/text_generation_server/server.py | 7 + server/text_generation_server/utils/tokens.py | 2 + .../text_generation_server/utils/weights.py | 31 +- 29 files changed, 1045 insertions(+), 888 deletions(-) create mode 100644 server/Makefile-vllm diff --git a/Dockerfile b/Dockerfile index 2a313c25..1a969383 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \ /opt/conda/bin/conda clean -ya - # Build Flash Attention CUDA kernels FROM kernel-builder as flash-att-builder @@ -109,6 +108,16 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build vllm CUDA kernels +FROM kernel-builder as vllm-builder + +WORKDIR /usr/src + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm + # Text Generation Inference base image FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base @@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages -# Copy build artifacts from transformers builder +# Copy build artifacts from custom kernels builder COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels +# Copy builds artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/README.md b/README.md index 8c8d9773..b74d2617 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ to power LLMs api-inference widgets. 
- Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
-- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
+- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py
index 1076126b..0289c61d 100644
--- a/integration-tests/models/test_flash_neox.py
+++ b/integration-tests/models/test_flash_neox.py
@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
     return flash_neox_handle.client
+@pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox(flash_neox, response_snapshot):
     response = await flash_neox.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     assert response == response_snapshot
+@pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     responses = await generate_load(
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2e2bc7a5..8497f807 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -115,12 +115,6 @@ struct Args {
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    /// The maximum allowed batch size during dynamic batching.
-    /// Using `max_batch_total_tokens` should be favored in general
-    /// as it's a finer way to control RAM usage.
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
     /// This represents the ratio of waiting queries vs running queries where
     /// you want to start considering pausing the running queries to include the waiting
     /// ones into the same batch.
@@ -134,6 +128,12 @@ struct Args {
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
+    /// Limits the number of tokens for the prefill operation.
+    /// Since this operation takes the most memory and is compute bound, it is useful
+    /// to limit the number of requests that can be sent.
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
     /// **IMPORTANT** This is one critical control to allow maximum usage
     /// of the available hardware.
     ///
@@ -146,19 +146,12 @@ struct Args {
     /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
     /// or a single query of `1000` tokens.
     ///
-    /// So you don't have to control that finely
-    /// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
-    /// want maximum flexibility.
However, for your users if they are asking for the full amount of - /// total tokens, they are likely to wait for a very long time to get a spot - /// in the batch (since they are going to be alone) so setting `max_batch_size` - /// and `max_total_tokens` can still be useful to prevent those long waiting times. - /// /// Overall this number should be the largest possible amount that fits the /// remaining memory (after the model is loaded). Since the actual memory overhead /// depends on other parameters like if you're using quantization, flash attention /// or the model implementation, text-generation-inference cannot infer this number /// automatically. - #[clap(default_value = "32000", long, env)] + #[clap(default_value = "16000", long, env)] max_batch_total_tokens: u32, /// This setting defines how many tokens can be passed before forcing the waiting @@ -180,9 +173,9 @@ struct Args { /// for end users. #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, - #[clap(default_value = "3000", long, short, env)] /// The port to listen on. + #[clap(default_value = "3000", long, short, env)] port: u16, /// The name of the socket for gRPC communication between the webserver @@ -329,6 +322,12 @@ fn shard_manager( // Copy current process env let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Use cuda allocator. It leads to less memory fragmentation + env.push(( + "PYTORCH_CUDA_ALLOC_CONF".into(), + "backend:cudaMallocAsync".into(), + )); + // Torch Distributed Env vars env.push(("RANK".into(), rank.to_string().into())); env.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -446,7 +445,7 @@ fn shard_manager( // We received a shutdown signal if *shutdown.lock().unwrap() { - p.terminate().unwrap(); + p.kill().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); tracing::info!("Shard {rank} terminated"); return; @@ -822,6 +821,10 @@ fn spawn_webserver( args.max_input_length.to_string(), "--max-total-tokens".to_string(), args.max_total_tokens.to_string(), + "--max-batch-prefill-tokens".to_string(), + args.max_batch_prefill_tokens.to_string(), + "--max-batch-total-tokens".to_string(), + args.max_batch_total_tokens.to_string(), "--waiting-served-ratio".to_string(), args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), @@ -834,15 +837,6 @@ fn spawn_webserver( args.model_id, ]; - // Deprecate max_batch_size - if let Some(max_batch_size) = args.max_batch_size { - argv.push("--max-batch-size".to_string()); - argv.push(max_batch_size.to_string()) - } else { - argv.push("--max-batch-total-tokens".to_string()); - argv.push(args.max_batch_total_tokens.to_string()) - } - // Model optional revision if let Some(ref revision) = args.revision { argv.push("--revision".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index a0f5a75e..5e061941 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -11,6 +11,8 @@ service TextGenerationService { rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); /// Remove requests from a cached batch rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Prefill batch and decode first token rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Decode token for a list of prefilled batches @@ -192,3 +194,13 @@ message DecodeResponse { /// Next batch (cached) optional CachedBatch batch = 2; } + +message WarmupRequest { + /// Batch to warmup on + 
Batch batch = 1; + /// Maximum number of tokens that the client will send + uint32 max_total_tokens = 2; +} + +/// Empty response +message WarmupResponse {} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 81f023ef..b5e0ccc0 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi use crate::pb::generate::v1::*; use crate::Result; use grpc_metadata::InjectTelemetryContext; +use std::cmp::min; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -94,6 +95,63 @@ impl Client { Ok(filtered_batch.batch) } + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + ) -> Result<()> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + + // Create requests + while n_tokens < max_prefill_tokens { + requests.push(Request { + id: 0, + // We truncate the input on the server side to be sure that it has the correct size + inputs: "_test ".to_string().repeat(max_input_length as usize), + truncate: min(max_input_length, max_prefill_tokens - n_tokens), + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + watermark: true, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 2, + stop_sequences: vec![], + ignore_eos_token: false, + }), + prefill_logprobs: true, + }); + n_tokens += max_input_length; + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_total_tokens, + }) + .inject_context(); + self.stub.warmup(request).await?.into_inner(); + Ok(()) + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index b81eed46..9dd173a0 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -87,6 +87,27 @@ impl ShardedClient { join_all(futures).await.pop().unwrap() } + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + ) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) + }) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/src/infer.rs b/router/src/infer.rs index f738f986..d0d22d3b 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -45,6 +45,7 @@ impl Infer { client: ShardedClient, validation: Validation, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, max_concurrent_requests: usize, @@ -61,6 +62,7 @@ impl Infer { tokio::spawn(batching_task( client, 
waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, queue.clone(), @@ -240,9 +242,11 @@ impl Infer { /// Will be launched in a background Tokio task /// /// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] async fn batching_task( mut client: ShardedClient, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, queue: Queue, @@ -257,8 +261,9 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = - queue.next_batch(None, max_batch_total_tokens).await + while let Some((mut entries, batch, span)) = queue + .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens) + .await { let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) .instrument(span) @@ -284,11 +289,12 @@ async fn batching_task( Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; - let token_budget = max_batch_total_tokens - batch_max_tokens; + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = - queue.next_batch(min_size, token_budget).await + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_batch_prefill_tokens, token_budget) + .await { // Tracking metrics if min_size.is_some() { diff --git a/router/src/main.rs b/router/src/main.rs index 7bbb6477..47d48e3f 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -32,10 +32,10 @@ struct Args { max_input_length: usize, #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - #[clap(long, env)] - max_batch_size: Option, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, #[clap(default_value = "32000", long, env)] max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] @@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> { max_stop_sequences, max_input_length, max_total_tokens, - max_batch_size, waiting_served_ratio, - mut max_batch_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, max_waiting_tokens, port, master_shard_uds_path, @@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> { .block_on(async { init_logging(otlp_endpoint, json_output); - if let Some(max_batch_size) = max_batch_size { - tracing::warn!("`max-batch-size` is deprecated. 
Use `max-batch-total-tokens` instead"); - max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32; - tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}"); - } - if tokenizer.is_none() { tracing::warn!( "Could not find a fast tokenizer implementation for {tokenizer_name}" @@ -161,10 +155,16 @@ fn main() -> Result<(), std::io::Error> { sha: None, pipeline_tag: None, }, - false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None } - }), + false => get_model_info(&tokenizer_name, &revision, authorization_token) + .await + .unwrap_or_else(|| { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + } + }), }; // if pipeline-tag == text-generation we default to return_full_text = true @@ -190,6 +190,17 @@ fn main() -> Result<(), std::io::Error> { .info() .await .expect("Unable to get shard info"); + + // Warmup model + tracing::info!("Warming up model"); + sharded_client + .warmup( + max_input_length as u32, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + .expect("Unable to warmup model"); tracing::info!("Connected"); // Binds on localhost @@ -206,6 +217,7 @@ fn main() -> Result<(), std::io::Error> { max_input_length, max_total_tokens, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, sharded_client, @@ -219,7 +231,7 @@ fn main() -> Result<(), std::io::Error> { ngrok_username, ngrok_password, ) - .await; + .await; Ok(()) }) } diff --git a/router/src/queue.rs b/router/src/queue.rs index 6d1d4d12..48e483a1 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -58,6 +58,7 @@ impl Queue { pub(crate) async fn next_batch( &self, min_size: Option, + prefill_token_budget: u32, token_budget: u32, ) -> Option { // Create response channel @@ -67,6 +68,7 @@ impl Queue { self.queue_sender .send(QueueCommand::NextBatch { min_size, + prefill_token_budget, token_budget, response_sender, span: Span::current(), @@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver span.in_scope(|| { - let next_batch = state.next_batch(min_size, token_budget); + let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size", state.entries.len() as f64); }), @@ -140,7 +143,12 @@ impl State { } // Get the next batch - fn next_batch(&mut self, min_size: Option, token_budget: u32) -> Option { + fn next_batch( + &mut self, + min_size: Option, + prefill_token_budget: u32, + token_budget: u32, + ) -> Option { if self.entries.is_empty() { return None; } @@ -184,7 +192,9 @@ impl State { decode_tokens += entry.request.stopping_parameters.max_new_tokens; - if (prefill_tokens + decode_tokens) > token_budget { + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens) > token_budget + { // Entry is over budget // Add it back to the front self.entries.push_front((id, entry)); @@ -259,6 +269,7 @@ enum QueueCommand { Append(Box, Span), NextBatch { min_size: Option, + prefill_token_budget: u32, token_budget: u32, response_sender: oneshot::Sender>, span: Span, @@ -328,8 +339,8 @@ mod tests { fn test_next_batch_empty() { let 
mut state = State::new(false); - assert!(state.next_batch(None, 1).is_none()); - assert!(state.next_batch(Some(1), 1).is_none()); + assert!(state.next_batch(None, 1, 1).is_none()); + assert!(state.next_batch(Some(1), 1, 1).is_none()); } #[test] @@ -340,7 +351,7 @@ mod tests { state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -356,7 +367,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), 2).is_none()); + assert!(state.next_batch(Some(2), 2, 2).is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -372,7 +383,7 @@ mod tests { state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -385,7 +396,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -408,8 +419,8 @@ mod tests { async fn test_queue_next_batch_empty() { let queue = Queue::new(false); - assert!(queue.next_batch(None, 1).await.is_none()); - assert!(queue.next_batch(Some(1), 1).await.is_none()); + assert!(queue.next_batch(None, 1, 1).await.is_none()); + assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); } #[tokio::test] @@ -420,7 +431,7 @@ mod tests { queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -433,11 +444,11 @@ mod tests { queue.append(entry3); // Not enough requests pending - assert!(queue.next_batch(Some(2), 2).await.is_none()); + assert!(queue.next_batch(Some(2), 2, 2).await.is_none()); // Not enough token budget - assert!(queue.next_batch(Some(1), 0).await.is_none()); + assert!(queue.next_batch(Some(1), 0, 0).await.is_none()); // Ok - let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap(); + let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap(); assert_eq!(entries2.len(), 1); assert!(entries2.contains_key(&2)); assert!(entries2.get(&2).unwrap().batch_time.is_some()); @@ -453,7 +464,7 @@ mod tests { queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -462,7 +473,7 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); - let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -476,6 +487,6 @@ mod tests { let (entry, _) = default_entry(); queue.append(entry); - assert!(queue.next_batch(None, 1).await.is_none()); + 
assert!(queue.next_batch(None, 1, 1).await.is_none()); } } diff --git a/router/src/server.rs b/router/src/server.rs index dd8bc874..ee96ead6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -514,6 +514,7 @@ pub async fn run( max_input_length: usize, max_total_tokens: usize, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, client: ShardedClient, @@ -582,6 +583,7 @@ pub async fn run( client, validation, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_concurrent_requests, diff --git a/server/Makefile-vllm b/server/Makefile-vllm new file mode 100644 index 00000000..af750733 --- /dev/null +++ b/server/Makefile-vllm @@ -0,0 +1,13 @@ +vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9 + +vllm: + # Clone vllm + git clone https://github.com/OlivierDehaene/vllm.git + +build-vllm: vllm + cd vllm && git fetch && git checkout $(vllm_commit) + cd vllm && python setup.py build + +install-vllm: build-vllm + pip uninstall vllm -y || true + cd vllm && python setup.py install \ No newline at end of file diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 5556529c..79fcd3aa 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -22,7 +22,9 @@ class Cache: del batch def clear(self): - self.cache.clear() + keys = list(self.cache.keys()) + for k in keys: + self.delete(k) def __len__(self): return len(self.cache.keys()) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba0853f5..6d47c6eb 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -122,7 +122,7 @@ class CausalLMBatch(Batch): position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 993e1e2a..07765e88 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -23,12 +23,16 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda import dropout_layer_norm +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -106,7 +110,7 @@ class FlashLlamaAttention(torch.nn.Module): prefix=f"{prefix}.rotary_emb", weights=weights ) - self.softmax_scale = self.head_size ** (-0.5) + self.softmax_scale = self.head_size**-0.5 self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load_multi( @@ -122,20 +126,22 @@ class FlashLlamaAttention(torch.nn.Module): weights=weights, bias=False, ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - 
start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -144,23 +150,25 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = qkv[:, 1:] + vllm_cache_ops.reshape_and_cache( + qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + ) - # output - attn_output = torch.empty_like(qkv[:, 0]) + # output tensor + attn_output = torch.empty_like(qkv[:, 0]) + + # Prefill + if start_seq_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -173,31 +181,19 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - query = qkv[:, 0] - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = qkv[:, 1:] - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - layer_past[:, 0], - layer_past[:, 1], + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + qkv[:, 0], + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -265,14 +261,13 @@ class FlashLlamaLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -281,14 +276,13 @@ class FlashLlamaLayer(nn.Module): normed_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) # faster post attention rms norm @@ -333,40 +327,18 @@ class FlashLlamaModel(torch.nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # Decode - else: - prefill 
= False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -380,34 +352,18 @@ class FlashLlamaModel(torch.nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - past_key_values[:, i], - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashLlamaForCausalLM(torch.nn.Module): @@ -423,31 +379,29 @@ class FlashLlamaForCausalLM(torch.nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.model( + ) -> torch.Tensor: + hidden_states = self.model( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 9c1020a5..9049878a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -25,11 +25,15 @@ from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module): self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -132,23 +138,25 @@ class 
FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = qkv[:, 1:] + vllm_cache_ops.reshape_and_cache( + qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + ) - # output - attn_output = torch.empty_like(qkv[:, 0]) + # output tensor + attn_output = torch.empty_like(qkv[:, 0]) + + # Prefill + if start_seq_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - query = qkv[:, 0] - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = qkv[:, 1:] - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - layer_past[:, 0], - layer_past[:, 1], + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + qkv[:, 0], + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module): ln1_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - 
len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( @@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - past_key_values[:, i], - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.final_layer_norm(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): @@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.gpt_neox( + ) -> torch.Tensor: + hidden_states = self.gpt_neox( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fa35c359..44aa7cb1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -4,11 +4,15 @@ import torch.distributed from torch import nn from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + if self.num_heads_kv == 1: + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) + else: + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + 
end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) @@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = kv - # Expand to query shape - kv = kv.expand(-1, 2, self.num_heads, self.head_size) + vllm_cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output + attn_output = torch.empty_like(query) + + # Prefill + if start_seq_prefill is not None: + if self.num_heads_kv == 1: + # Expand to query shape + kv = kv.expand(-1, 2, self.num_heads, self.head_size) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = kv - # Expand to query shape - kv = layer_past.expand(-1, 2, self.num_heads, self.head_size) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), + # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + self.kv_head_mapping = torch.arange( + 0, self.num_groups, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_heads) + def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) + vllm_cache_ops.reshape_and_cache( + kv[:, :, 0].contiguous(), + kv[:, :, 1].contiguous(), + kv_cache[0], + kv_cache[1], + slots, + ) + + # output + attn_output = torch.empty_like(query) + # Prefill - if prefill: - # Copy to layer past - layer_past[...] 
= kv + if start_seq_prefill is not None: # Expand to query shape kv = ( kv.unsqueeze(2) @@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module): .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) ) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = kv - # Expand to query shape - kv = ( - layer_past.unsqueeze(2) - .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) - .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) - ) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(kv, dim=2, index=0), - torch.select(kv, dim=2, index=1), + # kv_cache[1] => [num_blocks, num_groups, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense( @@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module): ln_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) mlp_output = self.mlp(ln_hidden_states) @@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): ln_attn, residual = self.ln_attn(hidden_states, residual) ln_mlp, _ = self.ln_mlp(residual) @@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module): ln_attn, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) # MLP. 
@@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel): for layer_id in range(config.num_hidden_layers) ] ) - self.cache_size = ( - 2, - self.h[0].self_attention.num_heads_kv, - self.h[0].self_attention.head_size, - ) + self.cache_size = self.h[0].self_attention.num_heads_kv elif config.model_type == "RefinedWeb": self.h = nn.ModuleList( [ @@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel): for layer_id in range(config.num_hidden_layers) ] ) - self.cache_size = ( - self.h[0].self_attention.num_groups, - 2, - self.h[0].self_attention.head_size, - ) + self.cache_size = self.h[0].self_attention.num_groups else: raise NotImplementedError( f"model_type {config.model_type} is not supported." @@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.h), - *self.cache_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( @@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - torch.select(past_key_values, dim=1, index=i), - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.h), - *self.cache_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashRWForCausalLM(FlashRWPreTrainedModel): @@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.transformer( + ) -> torch.Tensor: + hidden_states = self.transformer( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + 
slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 4eb0034d..04eedef7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -3,11 +3,15 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.c_attn(hidden_states) @@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) + vllm_cache_ops.reshape_and_cache( + key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output + attn_output = torch.empty_like(query) + # Prefill - if prefill: - # Copy to layer past - layer_past[...] 
= key_value + if start_seq_prefill is not None: # Expand from 1 to num_heads key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = key_value - # Expand from 1 to num_heads - key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(key_value, dim=1, index=0), - torch.select(key_value, dim=1, index=1), + # kv_cache[1] => [num_blocks, 1, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -361,27 +357,25 @@ class Block(nn.Module): self, hidden_states, residual, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.attn( hidden_states, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) if self.process_group.size() > 1: torch.distributed.all_reduce(hidden_states, group=self.process_group) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_zeros( - (len(input_ids), len(self.h), 2, 1, self.head_size) - ) - # Decode - else: - prefill = False - residual = None for i, layer in enumerate(self.h): hidden_states, residual = layer( hidden_states, residual, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - torch.select(past_key_values, dim=1, index=i), - 
past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - (pre_allocate_past_size, len(self.h), 2, 1, self.head_size) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashSantacoderForCausalLM(nn.Module): @@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.transformer( + ) -> torch.Tensor: + hidden_states = self.transformer( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 12679e9d..19deca86 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): try: self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) except RuntimeError: - self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) + self.shared = TensorParallelEmbedding( + prefix="encoder.embed_tokens", weights=weights + ) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index ecea998e..94b14f85 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,11 +1,14 @@ +import math +import itertools import torch import torch.distributed import numpy as np from dataclasses import dataclass +from loguru import logger from opentelemetry import trace -from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel +from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict from text_generation_server.models import Model @@ -20,6 +23,92 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke tracer = trace.get_tracer(__name__) +BLOCK_SIZE = 16 +# Will be set in warmup +CACHE_MANAGER: Optional["CacheManager"] = None + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + + 
element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate(self, batch: "FlashCausalLMBatch"): + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= batch.blocks + ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[: batch.blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(batch), batch.max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks : cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + batch.needed_blocks_slots = None + batch.block_tables = block_tables + batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device) + batch.slots = torch.concat(slots).to(batch.input_ids.device) + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + @dataclass class FlashCausalLMBatch(Batch): @@ -32,23 +121,29 @@ class FlashCausalLMBatch(Batch): input_ids: torch.Tensor position_ids: torch.Tensor - # Indices to copy present to the correct indices is the pre-allocated past key values - past_present_indices: torch.Tensor - - # tensor of length b holding starting offset of each sequence - start_seq: torch.Tensor - # tensor of length b holding ending offset of each sequence - end_seq: torch.Tensor # tensor of length b holding starting offset of each sequence, only used in prefill start_seq_prefill: Optional[torch.Tensor] # tensor of length b holding ending offset of each sequence, only used in prefill end_seq_prefill: Optional[torch.Tensor] - # tensor of length b holding starting offset of each query sequence, only used in decode - start_seq_q: Optional[torch.Tensor] - # tensor of length b holding ending offset of each query sequence, only used in decode - end_seq_q: Optional[torch.Tensor] - # past key values, only used in decode - past_key_values: Optional[torch.Tensor] + + # Paged Attention values + + # Set when creating the batch + # CPU tensor of length b indicating the start of each sequence in slots + start_slots: torch.Tensor + # tensor of indices of the currently used slots, length = 
\sum_{i=0}^{b} s_i in prefill, length = b in decode + slot_indices: torch.Tensor + # List of tuple of ints representing the number of blocks and slots needed by each sequence + needed_blocks_slots: Optional[List[Tuple[int, int]]] + + # Set in prefill by the CacheManager + # list of length b of list of length s_i // block_size + block_tables: Optional[List[List[int]]] + # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences + block_tables_tensor: Optional[torch.Tensor] + # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences + slots: Optional[torch.Tensor] + max_seqlen: int # Prefill metadata tensors to efficiently compute logprobs @@ -62,6 +157,7 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] + input_lengths_tensor: torch.Tensor prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -69,15 +165,17 @@ class FlashCausalLMBatch(Batch): next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] - # Maximum number of tokens this batch will grow to - max_tokens: int + # Number of blocks in this batch + blocks: int + # Maximum number of blocks + max_blocks: int def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), - max_tokens=self.max_tokens, + max_tokens=self.blocks * BLOCK_SIZE, ) @classmethod @@ -99,12 +197,11 @@ class FlashCausalLMBatch(Batch): )["input_ids"] position_ids = [] - past_present_indices = [] - start_seq = [] - end_seq = [] start_seq_prefill = [] end_seq_prefill = [] - max_seqlen = 0 + needed_blocks_slots = [] + start_slots = [] + slot_indices = [] input_lengths = [] prefix_offsets = [] @@ -126,7 +223,10 @@ class FlashCausalLMBatch(Batch): cumulative_max_length = 0 prefill_out_cumulative_length = 0 + blocks = 0 + max_seqlen = 0 max_length = 0 + max_blocks = 0 # Parse batch for i, (r, tokenized_input) in enumerate( @@ -138,7 +238,6 @@ class FlashCausalLMBatch(Batch): tokenized_input = tokenized_input[-r.truncate :] input_length = len(tokenized_input) - max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) prefix_offsets.append(input_length - 5) @@ -153,8 +252,6 @@ class FlashCausalLMBatch(Batch): # Add cumulative lengths of all previous inputs start_seq_prefill.append(cumulative_length) end_seq_prefill.append(cumulative_length + input_length) - start_seq.append(cumulative_max_length) - end_seq.append(cumulative_max_length + input_length) next_token_chooser_parameters.append(r.parameters) @@ -164,6 +261,21 @@ class FlashCausalLMBatch(Batch): max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + # Paged attention + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + blocks += needed_blocks + needed_blocks_slots.append((needed_blocks, total_tokens)) + start_slots.append(cumulative_max_length) + + request_slot_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + slot_indices.append(request_slot_indices) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs @@ -184,22 +296,17 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 
1) prefill_out_cumulative_length += 1 - request_past_present_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - past_present_indices.append(request_past_present_indices) - # Update - # Remove one as the first token des not have a past cumulative_length += input_length - cumulative_max_length += input_length + max_new_tokens - 1 + cumulative_max_length += total_tokens + max_seqlen = max(max_seqlen, input_length) + max_blocks = max(max_blocks, needed_blocks) max_length = max(max_length, input_length + max_new_tokens) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device ) + start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -212,34 +319,28 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) - start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32) - end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32) if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) - - past_present_indices = np.concatenate(past_present_indices, dtype=np.int64) - - start_seq_prefill = torch.tensor( - start_seq_prefill, device=device, dtype=torch.int32 - ) - end_seq_prefill = torch.tensor( - end_seq_prefill, device=device, dtype=torch.int32 - ) + slot_indices = torch.cat(slot_indices) else: input_ids = all_input_ids[0] position_ids = position_ids[0] + slot_indices = slot_indices[0] - past_present_indices = past_present_indices[0] - - start_seq_prefill = start_seq - end_seq_prefill = end_seq + start_seq_prefill = torch.tensor( + start_seq_prefill, device=device, dtype=torch.int32 + ) + end_seq_prefill = torch.tensor( + end_seq_prefill, device=device, dtype=torch.int32 + ) + position_ids = position_ids.to(device) + slot_indices = slot_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) - past_present_indices = torch.tensor( - past_present_indices, device=device, dtype=torch.int64 + input_lengths_tensor = torch.tensor( + input_lengths, dtype=torch.int32, device=device ) if all_prefill_logprobs: @@ -262,26 +363,28 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=start_seq_prefill, end_seq_prefill=end_seq_prefill, - start_seq_q=None, - end_seq_q=None, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=needed_blocks_slots, + block_tables=None, + block_tables_tensor=None, + slots=None, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, - past_key_values=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=cumulative_max_length, + blocks=blocks, + max_blocks=max_blocks, ) @tracer.start_as_current_span("filter") @@ -294,28 +397,24 @@ class FlashCausalLMBatch(Batch): device = self.input_ids.device - # 
Cumulative length - cumulative_max_length = 0 - # New values after filtering requests_idx_mapping = {} # Used to index into tensors indices = [] - # past indices to keep - past_indices = torch.zeros( - self.past_key_values.shape[0], dtype=torch.bool, device=device + # slots to keep after filtering + slot_filtering_indices = torch.zeros( + self.slots.shape[0], dtype=torch.bool, device=device ) # Create on CPU to only move to GPU once instead of at every copy - start_seq = torch.empty(len(request_ids), dtype=torch.int32) - end_seq = torch.empty(len(request_ids), dtype=torch.int32) - start_seq_q = self.start_seq_q[: len(request_ids)] - end_seq_q = self.end_seq_q[: len(request_ids)] + slot_indices = torch.empty(len(request_ids), dtype=torch.int64) max_seqlen = 0 requests = [] + start_slots = [] + block_tables = [] all_input_ids = [] input_lengths = [] @@ -324,6 +423,11 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] + blocks = 0 + max_blocks = 0 + # Cumulative length + cumulative_max_length = 0 + for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) @@ -348,28 +452,51 @@ class FlashCausalLMBatch(Batch): stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) + request_block_table = self.block_tables[idx] + blocks += len(request_block_table) + block_tables.append(request_block_table) + start_slots.append(cumulative_max_length) + # Copy to tensor (CPU) - start_seq[i] = cumulative_max_length - end_seq[i] = cumulative_max_length + request_input_length + slot_indices[i] = cumulative_max_length + request_input_length - 1 # Set slice - past_indices[ - self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 + slot_filtering_indices[ + self.start_slots[idx] : self.start_slots[idx] + + request_input_length + + remaining_tokens + - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 + max_blocks = max(max_blocks, len(request_block_table)) + + global CACHE_MANAGER + block_indices_to_free = [] + # Iterate on all requests + for i, r in enumerate(self.requests): + # Filter requests that are not part of the new batch + if r.id not in requests_idx_mapping.keys(): + block_indices_to_free.extend(self.block_tables[i]) + # Free blocks + CACHE_MANAGER.free(block_indices_to_free) + # Needed to avoid dropping blocks when the batches will go out of scope + self.block_tables = None + # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] + block_tables_tensor = self.block_tables_tensor[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) - past_key_values = self.past_key_values[past_indices] + + start_slots = torch.tensor(start_slots, dtype=torch.int64) # Move to GPU now that we have the whole tensor - start_seq = start_seq.to(device) - end_seq = end_seq.to(device) - past_present_indices = end_seq - 1 + slot_indices = slot_indices.to(device) return FlashCausalLMBatch( batch_id=self.batch_id, @@ -377,26 +504,28 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=None, end_seq_prefill=None, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_slots=start_slots, + slot_indices=slot_indices, + 
needed_blocks_slots=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - past_key_values=past_key_values, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=cumulative_max_length, + blocks=blocks, + max_blocks=max_blocks, ) @classmethod @@ -406,22 +535,46 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} - total_batch_size = sum([len(b) for b in batches]) - - dtype = batches[0].past_key_values.dtype - device = batches[0].input_ids.device + blocks = 0 + total_batch_size = 0 + total_slots = 0 + max_blocks = 0 + max_length = 0 + max_seqlen = 0 + for b in batches: + total_batch_size += len(b) + total_slots += len(b.slots) + blocks += b.blocks + max_blocks = max(max_blocks, b.max_blocks) + max_seqlen = max(max_seqlen, b.max_seqlen) + max_length = max( + max_length, + max( + input_length + + stopping_criteria.max_new_tokens + - stopping_criteria.current_tokens + for input_length, stopping_criteria in zip( + b.input_lengths, b.stopping_criterias + ) + ), + ) input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - start_seq = batches[0].start_seq.new_empty(total_batch_size) - end_seq = batches[0].end_seq.new_empty(total_batch_size) - start_seq_q = torch.arange( - 0, total_batch_size, device=device, dtype=torch.int32 + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + block_tables_tensor = batches[0].block_tables_tensor.new_zeros( + (total_batch_size, max_blocks) + ) + all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( + (total_batch_size, max_length) ) - end_seq_q = start_seq_q + 1 - max_seqlen = 0 - past_key_values = [] + start_slots = [] + block_tables = [] all_input_ids = [] input_lengths = [] @@ -433,8 +586,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 - max_tokens = 0 - max_length = 0 + cumulative_slots = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -448,16 +600,27 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + slots[slots_start_index:slots_end_index] = batch.slots - start_seq[start_index:end_index] = batch.start_seq + max_tokens - end_seq[start_index:end_index] = batch.end_seq + max_tokens + all_input_ids_tensor[ + start_index:end_index, : batch.all_input_ids_tensor.shape[1] + ] = batch.all_input_ids_tensor[:, :max_length] - max_seqlen = max(max_seqlen, batch.max_seqlen) + block_tables_tensor[ + start_index:end_index, : batch.block_tables_tensor.shape[1] + ] = batch.block_tables_tensor[:, :max_blocks] + 
start_slots.append(batch.start_slots + cumulative_slots) + + block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) input_lengths.extend(batch.input_lengths) @@ -466,73 +629,59 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) - past_key_values.append(batch.past_key_values) # Update cumulative_batch_size += len(batch) - max_tokens += batch.max_tokens - max_length = max( - max_length, - max( - input_length - + stopping_criteria.max_new_tokens - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - batch.input_lengths, batch.stopping_criterias - ) - ), - ) + cumulative_slots += len(batch.slots) - past_key_values = torch.cat(past_key_values, dim=0) - past_present_indices = end_seq - 1 - - all_input_ids_tensor = torch.zeros( - (total_batch_size, max_length), dtype=torch.int64, device=device - ) - - cumulative_batch_size = 0 - for i, batch in enumerate(batches): - start_index = cumulative_batch_size - end_index = cumulative_batch_size + len(batch) - - all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:, :max_length] - - cumulative_batch_size += len(batch) + start_slots = torch.concat(start_slots) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype=dtype, device=device + next_token_chooser_parameters, + dtype=batches[0].next_token_chooser.dtype, + device=batches[0].next_token_chooser.device, ) + # Needed to avoid dropping blocks when the batches will go out of scope + for b in batches: + b.block_tables = None + return FlashCausalLMBatch( batch_id=batches[0].batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=None, end_seq_prefill=None, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - past_key_values=past_key_values, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + blocks=blocks, + max_blocks=max_blocks, ) + def __del__(self): + if self.block_tables is not None and self.block_tables: + global CACHE_MANAGER + # Free blocks + CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) + def __len__(self): return len(self.requests) @@ -540,32 +689,19 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, - model_cls: Type[PreTrainedModel], - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + num_layers: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + rank: int = 0, + world_size: int = 1, ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise 
NotImplementedError("FlashCausalLM is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = model_cls.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ).to(device) + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads + self.head_size = head_size super(FlashCausalLM, self).__init__( model=model, @@ -573,12 +709,38 @@ class FlashCausalLM(Model): requires_padding=False, dtype=dtype, device=device, + rank=rank, + world_size=world_size, ) @property def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): + global CACHE_MANAGER + + torch.cuda.empty_cache() + try: + CACHE_MANAGER = CacheManager( + # Adds some wiggle room + math.ceil(max_total_tokens / BLOCK_SIZE) + 10, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + _, batch = self.generate_token(batch) + except Exception as e: + logger.exception( + f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " + f"prefill tokens. " + f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" + ) + raise e + del batch + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False @@ -588,28 +750,27 @@ class FlashCausalLM(Model): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq: torch.Tensor, - end_seq: torch.Tensor, - start_seq_q: Optional[torch.Tensor], - end_seq_q: Optional[torch.Tensor], + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, max_s: int, - past_present_indices: torch.Tensor, - past_key_values: Optional = None, - pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + global CACHE_MANAGER + # Model Forward return self.model.forward( input_ids=input_ids, position_ids=position_ids, - start_seq=start_seq, - end_seq=end_seq, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_seq_prefill=start_seq_prefill, + end_seq_prefill=end_seq_prefill, + kv_cache=CACHE_MANAGER.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, max_s=max_s, - past_present_indices=past_present_indices, - past_key_values=past_key_values, - pre_allocate_past_size=pre_allocate_past_size, lm_head_indices=lm_head_indices, ) @@ -617,31 +778,22 @@ class FlashCausalLM(Model): def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: - prefill = batch.past_key_values is None + prefill = batch.start_seq_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - if prefill: - # Ask to pre-allocate kv to its max size - # == Sum over batch size (number of tokens + max_new_tokens) - batch size - pre_allocate_past_size = batch.max_tokens - start_seq = batch.start_seq_prefill - end_seq = batch.end_seq_prefill - else: - pre_allocate_past_size = None - start_seq = batch.start_seq - end_seq = batch.end_seq + if batch.needed_blocks_slots: + # Allocate blocks to this batch + 
CACHE_MANAGER.allocate(batch) - out, present = self.forward( + out = self.forward( batch.input_ids, batch.position_ids, - start_seq, - end_seq, - batch.start_seq_q, - batch.end_seq_q, + batch.start_seq_prefill, + batch.end_seq_prefill, + batch.block_tables_tensor, + batch.slots[batch.slot_indices], + batch.input_lengths_tensor, batch.max_seqlen, - batch.past_present_indices, - batch.past_key_values, - pre_allocate_past_size, batch.prefill_head_indices, ) @@ -662,12 +814,8 @@ class FlashCausalLM(Model): # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - # Create batch.start_seq_q and batch.end_seq_q for decode - batch.start_seq_q = torch.arange( - 0, len(batch), device=self.device, dtype=torch.int32 - ) - batch.end_seq_q = batch.start_seq_q + 1 next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1] # We do not need start_seq_prefill and end_seq_prefill anymore batch.start_seq_prefill = None batch.end_seq_prefill = None @@ -731,8 +879,8 @@ class FlashCausalLM(Model): # Set values in batch batch.input_ids = next_input_ids batch.position_ids = next_position_ids + 1 - batch.past_present_indices = batch.end_seq - batch.end_seq = batch.end_seq + 1 + batch.input_lengths_tensor += 1 + batch.slot_indices += 1 if prefill and prefill_logprobs: # Get prefill logprobs @@ -755,7 +903,6 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.all_input_ids_tensor, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, next_token_ids, @@ -770,7 +917,6 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, - all_input_ids_tensor, do_sample, seed, next_token_id, @@ -845,19 +991,20 @@ class FlashCausalLM(Model): generations.append(generation) - new_input_length = input_length + 1 - # Update values - batch.input_lengths[i] = new_input_length + batch.input_lengths[i] = input_length + 1 batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids + if stopped: + del batch + # No need to return a batch if we know that all requests stopped + return generations, None + batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 - batch.past_key_values = present - # No need to return a batch if we know that all requests stopped - return generations, batch if not stopped else None + return generations, batch diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index a80d58cb..2c59f01e 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM): model = FlashLlamaForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashLlama, self).__init__( model=model, tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_heads, + head_size=model.model.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 4847571d..e64af0c6 100644 --- a/server/text_generation_server/models/flash_neox.py +++ 
b/server/text_generation_server/models/flash_neox.py @@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM): model = FlashGPTNeoXForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashNeoXSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.gpt_neox.layers), + num_kv_heads=model.gpt_neox.num_heads, + head_size=model.gpt_neox.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 5f963bfb..a55f9118 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM): model = FlashRWForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashRWSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.transformer.h), + num_kv_heads=model.transformer.cache_size, + head_size=model.transformer.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a71c0061..ef202785 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -52,17 +52,22 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group, - aliases = {"transformer.wte.weight": ["lm_head.weight"]} + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) model = FlashSantacoderForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashSantacoderSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.transformer.h), + num_kv_heads=1, + head_size=model.transformer.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 6b8472a5..f8460fc2 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -22,6 +22,9 @@ class Model(ABC): rank: int = 0, world_size: int = 1, ): + if torch.cuda.is_available(): + torch.cuda.set_per_process_memory_fraction(1.0) + self.model = model.eval() self.tokenizer = tokenizer self.all_special_ids = set(tokenizer.all_special_ids) @@ -55,6 +58,9 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError + def warmup(self, batch: B, max_total_tokens: int): + self.generate_token(batch) + def decode_token( self, all_input_ids: List[int], diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3ad5698c..999b6637 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch): read_offsets.append(1) 
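The flash model constructors above now hand `num_layers`, `num_kv_heads` and `head_size` to `FlashCausalLM` because those three values, together with the dtype, are all the `CacheManager` needs to size the paged KV cache at warmup. A rough way to see what a given `--max-batch-total-tokens` budget costs in GPU memory is sketched below in plain Python; the helper name and the example model geometry are illustrative and not taken from this repository.

import math

BLOCK_SIZE = 16  # tokens per paged-attention block, matching the constant above

def kv_cache_footprint(max_total_tokens, num_layers, num_kv_heads, head_size, dtype_bytes=2):
    # Illustrative estimate of the warmup-time allocation; returns (num_blocks, total_bytes).
    # The "+ 10" wiggle room mirrors the warmup code above.
    num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE) + 10
    # One key tensor and one value tensor per layer; each holds
    # num_kv_heads * head_size * BLOCK_SIZE elements per block
    # (the key tensor is stored re-packed, but the element count is the same).
    elements_per_block = num_kv_heads * head_size * BLOCK_SIZE
    total_bytes = num_blocks * num_layers * 2 * elements_per_block * dtype_bytes
    return num_blocks, total_bytes

# Hypothetical numbers in the ballpark of a 7B Llama-style model in float16:
blocks, size = kv_cache_footprint(16000, num_layers=32, num_kv_heads=32, head_size=128)
print(blocks, f"{size / 2**30:.1f} GiB")

With multi-query attention (`num_kv_heads=1`, as passed by the Santacoder constructor above) the same token budget needs roughly 1/32 of that memory, which is why MQA models can keep far more sequences resident in the cache.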
all_decoder_input_ids = decoder_input_ids.view(-1).split(1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e1bd8412..6cc5beeb 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -53,6 +53,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + async def Warmup(self, request, context): + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + self.model.warmup(batch, request.max_total_tokens) + return generate_pb2.WarmupResponse() + async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e6e512bc..b83af591 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser: self.seeds = seeds self.do_sample = do_sample + self.dtype = dtype + self.device = device def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): if self.watermark_processor is not None: diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 9d371834..83d9df68 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -5,7 +5,14 @@ import torch class Weights: - def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + aliases: Optional[Dict[str, List[str]]] = None, + ): routing = {} for filename in filenames: with safe_open(filename, framework="pytorch") as f: @@ -43,7 +50,7 @@ class Weights: return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename, tensor_name= self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -94,12 +101,20 @@ class Weights: def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: - qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) - qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) - scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) w = 
[self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -118,7 +133,9 @@ class Weights: try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) From 3b0c979efcccd8ca51f59f1f982bfbbc842d06c9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 30 Jun 2023 20:07:49 +0200 Subject: [PATCH 02/11] feat(router): arg validation (#519) --- launcher/src/main.rs | 4 ++-- router/src/main.rs | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8497f807..9d6cd4dd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -101,7 +101,7 @@ struct Args { /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. /// Please note that some models have a finite range of sequence they can handle. - #[clap(default_value = "1000", long, env)] + #[clap(default_value = "1024", long, env)] max_input_length: usize, /// This is the most important value to set as it defines the "memory budget" @@ -112,7 +112,7 @@ struct Args { /// `1511` max_new_tokens. /// The larger this value, the larger amount each request will be in your RAM /// and the less effective batching can be. - #[clap(default_value = "1512", long, env)] + #[clap(default_value = "2048", long, env)] max_total_tokens: usize, /// This represents the ratio of waiting queries vs running queries where diff --git a/router/src/main.rs b/router/src/main.rs index 47d48e3f..f782be09 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -28,15 +28,15 @@ struct Args { max_best_of: usize, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, - #[clap(default_value = "1000", long, env)] + #[clap(default_value = "1024", long, env)] max_input_length: usize, - #[clap(default_value = "1512", long, env)] + #[clap(default_value = "2048", long, env)] max_total_tokens: usize, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] max_batch_prefill_tokens: u32, - #[clap(default_value = "32000", long, env)] + #[clap(default_value = "16000", long, env)] max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, @@ -97,8 +97,18 @@ fn main() -> Result<(), std::io::Error> { ngrok_password, } = args; + // Validate args + if max_input_length as u32 > max_batch_prefill_tokens { + panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")); + } + if max_total_tokens as u32 > max_batch_total_tokens { + panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}")); + } if validation_workers == 0 { - panic!("validation_workers must be > 0"); + panic!("`validation_workers` must be > 0"); } // CORS allowed origins From ecf6dc3a5a31c1b0e1ed48988ddf2416b5e35660 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 30 Jun 2023 20:30:09 +0200 Subject: [PATCH 03/11] feat: Add the option to force another dtype than `f16`. (#513) --- launcher/src/main.rs | 32 +++++++++++ server/text_generation_server/cli.py | 15 ++++- .../text_generation_server/models/__init__.py | 55 +++++++++++++++++-- server/text_generation_server/models/bloom.py | 3 +- .../models/causal_lm.py | 3 +- .../models/flash_llama.py | 3 +- .../models/flash_neox.py | 3 +- .../text_generation_server/models/flash_rw.py | 3 +- .../models/flash_santacoder.py | 3 +- .../models/galactica.py | 3 +- .../text_generation_server/models/gpt_neox.py | 3 +- server/text_generation_server/models/opt.py | 3 +- server/text_generation_server/models/rw.py | 3 +- .../models/santacoder.py | 3 +- .../models/seq2seq_lm.py | 3 +- server/text_generation_server/models/t5.py | 3 +- server/text_generation_server/server.py | 10 +++- 17 files changed, 130 insertions(+), 21 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9d6cd4dd..51131f42 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization { } } +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Dtype { + Float16, + BFloat16, +} + +impl std::fmt::Display for Dtype { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + Dtype::Float16 => { + write!(f, "float16") + } + Dtype::BFloat16 => { + write!(f, "bfloat16") + } + } + } +} + /// App Configuration #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -71,6 +91,10 @@ struct Args { #[clap(long, env, value_enum)] quantize: Option, + /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. + #[clap(long, env, value_enum)] + dtype: Option, + /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is /// encouraged when loading a model with custom code to ensure no malicious code has been /// contributed in a newer revision. 
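The three panics added to `router/src/main.rs` above encode a simple ordering between the limits: a single prompt must fit in one prefill batch, a prefill batch must fit in the total-token budget, and a single request's whole sequence must also fit in that budget. Restated as a plain-Python sketch (the names mirror the CLI flags; this is not the router code itself):

def validate_limits(max_input_length, max_total_tokens,
                    max_batch_prefill_tokens, max_batch_total_tokens,
                    validation_workers):
    # Same invariants as the Rust panics above, expressed as exceptions.
    if max_input_length > max_batch_prefill_tokens:
        raise ValueError("`max_batch_prefill_tokens` must be >= `max_input_length`")
    if max_batch_prefill_tokens > max_batch_total_tokens:
        raise ValueError("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`")
    if max_total_tokens > max_batch_total_tokens:
        raise ValueError("`max_total_tokens` must be <= `max_batch_total_tokens`")
    if validation_workers == 0:
        raise ValueError("`validation_workers` must be > 0")

# The first four values are the new defaults shown above (1024 / 2048 / 4096 / 16000)
# and satisfy every check; the worker count here is just an arbitrary positive example.
validate_limits(1024, 2048, 4096, 16000, 2)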
@@ -258,6 +282,7 @@ fn shard_manager( model_id: String, revision: Option, quantize: Option, + dtype: Option, trust_remote_code: bool, uds_path: String, rank: usize, @@ -307,6 +332,11 @@ fn shard_manager( shard_argv.push(quantize.to_string()) } + if let Some(dtype) = dtype { + shard_argv.push("--dtype".to_string()); + shard_argv.push(dtype.to_string()) + } + // Model optional revision if let Some(revision) = revision { shard_argv.push("--revision".to_string()); @@ -743,6 +773,7 @@ fn spawn_shards( let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); let quantize = args.quantize; + let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; let master_port = args.master_port; let disable_custom_kernels = args.disable_custom_kernels; @@ -753,6 +784,7 @@ fn spawn_shards( model_id, revision, quantize, + dtype, trust_remote_code, uds_path, rank, diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index aeb1f13b..3463049a 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -16,12 +16,18 @@ class Quantization(str, Enum): gptq = "gptq" +class Dtype(str, Enum): + float16 = "float16" + bloat16 = "bfloat16" + + @app.command() def serve( model_id: str, revision: Optional[str] = None, sharded: bool = False, quantize: Optional[Quantization] = None, + dtype: Optional[Dtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", logger_level: str = "INFO", @@ -64,7 +70,14 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value - server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path) + dtype = None if dtype is None else dtype.value + if dtype is not None and quantize is not None: + raise RuntimeError( + "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." 
+ ) + server.serve( + model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path + ) @app.command() diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 2abde685..e45e198a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -100,11 +100,25 @@ def get_model( revision: Optional[str], sharded: bool, quantize: Optional[str], + dtype: Optional[str], trust_remote_code: bool, ) -> Model: + if dtype is None: + dtype = torch.float16 + elif dtype == "float16": + dtype = torch.float16 + elif dtype == "bfloat16": + dtype = torch.bfloat16 + else: + raise RuntimeError(f"Unknown dtype {dtype}") + if "facebook/galactica" in model_id: return GalacticaSharded( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + dtypetrust_remote_code=trust_remote_code, ) if model_id.startswith("bigcode/"): @@ -113,6 +127,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) elif sharded: @@ -124,6 +139,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -138,6 +154,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) elif sharded: @@ -149,12 +166,17 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) if model_type == "bloom": return BLOOMSharded( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, ) elif model_type == "gpt_neox": @@ -163,6 +185,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) elif sharded: @@ -170,6 +193,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) else: @@ -177,6 +201,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -186,6 +211,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) elif sharded: @@ -195,6 +221,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -210,6 +237,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) raise NotImplementedError( @@ -221,6 +249,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) else: @@ -228,12 +257,17 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) elif model_type == "opt": return OPTSharded( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, ) elif model_type == "t5": @@ -241,6 +275,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -253,11 +288,19 @@ def get_model( if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + 
trust_remote_code=trust_remote_code, ) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: return Seq2SeqLM( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, ) auto_map = config_dict.get("auto_map", None) @@ -267,6 +310,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) if "AutoModelForSeq2SeqLM" in auto_map.keys(): @@ -274,6 +318,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 50b3b76a..101da207 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 6d47c6eb..cbdf4808 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -454,11 +454,12 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 2c59f01e..417ccabb 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index e64af0c6..61004d8e 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only 
available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index a55f9118..12b862d7 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index ef202785..415ec2df 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 01e1c773..01e58bad 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 0abf0239..91877fa0 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 16cb48b7..d407b44a 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -22,12 +22,13 @@ class OPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if 
torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 2b1e4959..92bb135b 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -12,11 +12,12 @@ class RW(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index d0fd3070..a2b38737 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -19,11 +19,12 @@ class SantaCoder(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 999b6637..9e5c21d1 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -504,11 +504,12 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index c89462fc..1b7073af 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 6cc5beeb..c375330a 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -106,6 +106,7 @@ def serve( revision: Optional[str], sharded: bool, quantize: Optional[str], + dtype: Optional[str], trust_remote_code: bool, uds_path: Path, ): @@ -114,6 +115,7 @@ def serve( revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, + dtype: Optional[str] = None, trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" @@ -128,7 +130,9 @@ def serve( server_urls = [local_url] try: - 
model = get_model(model_id, revision, sharded, quantize, trust_remote_code) + model = get_model( + model_id, revision, sharded, quantize, dtype, trust_remote_code + ) except Exception: logger.exception("Error when initializing model") raise @@ -159,4 +163,6 @@ def serve( logger.info("Signal received. Shutting down") await server.stop(0) - asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code)) + asyncio.run( + serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code) + ) From 2b53d71991e8fe975be41a82ffe3b52b0bcd40a3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 30 Jun 2023 23:09:20 +0200 Subject: [PATCH 04/11] fix(launcher): fix issue where launcher does not properly report shard failures (#522) --- launcher/src/main.rs | 58 ++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 51131f42..30abe88a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read}; use std::path::Path; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::TryRecvError; -use std::sync::Arc; -use std::sync::{mpsc, Mutex}; +use std::sync::{mpsc, Arc}; use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; @@ -274,7 +273,7 @@ struct Args { #[derive(Debug)] enum ShardStatus { Ready, - Failed((usize, String)), + Failed((usize, Option)), } #[allow(clippy::too_many_arguments)] @@ -296,7 +295,7 @@ fn shard_manager( watermark_delta: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, - shutdown: Arc>, + shutdown: Arc, _shutdown_sender: mpsc::Sender<()>, ) { // Get UDS path @@ -433,20 +432,20 @@ fn shard_manager( } } status_sender - .send(ShardStatus::Failed((rank, err.to_string()))) + .send(ShardStatus::Failed((rank, Some(err.to_string())))) .unwrap(); return; } }; // Redirect STDOUT to the console - let shard_stdout = p.stdout.take().unwrap(); + let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); + let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); thread::spawn(move || { // Enter shard-manager tracing span - let stdout = BufReader::new(shard_stdout); let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); - for line in stdout.lines() { + for line in shard_stdout_reader.lines() { // Parse loguru logs if let Ok(log) = serde_json::from_str::(&line.unwrap()) { log.trace(); @@ -460,8 +459,22 @@ fn shard_manager( loop { // Process exited if let Some(exit_status) = p.poll() { - let mut err = String::new(); - p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); + // We read stderr in another thread as it seems that `read_to_string` can block + // indefinitely in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + let mut err = String::new(); + shard_stderr_reader.read_to_string(&mut err).unwrap(); + err_sender.send(err).unwrap_or(()); + }); + + let err = err_receiver + .recv_timeout(Duration::from_millis(100)) + .map_err(|err| { + tracing::error!("Unable to read shard {rank} error from stderr"); + err + }) + .ok(); if let ExitStatus::Signaled(signal) = exit_status { tracing::error!("Shard process was signaled to shutdown with signal {signal}"); @@ -474,7 +487,7 @@ fn shard_manager( } // We received a shutdown signal - if *shutdown.lock().unwrap() { + if shutdown.load(Ordering::SeqCst) { p.kill().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); 
tracing::info!("Shard {rank} terminated"); @@ -494,14 +507,11 @@ fn shard_manager( } } -fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receiver<()>) { +fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>) { tracing::info!("Shutting down shards"); // Update shutdown value to true // This will be picked up by the shard manager - { - let mut shutdown = shutdown.lock().unwrap(); - *shutdown = true; - } + shutdown.store(true, Ordering::SeqCst); // Wait for shards to shutdown // This will block till all shutdown_sender are dropped @@ -743,7 +753,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L fn spawn_shards( num_shard: usize, args: &Args, - shutdown: Arc>, + shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, shutdown_sender: mpsc::Sender<()>, status_receiver: &mpsc::Receiver, @@ -819,7 +829,10 @@ fn spawn_shards( sleep(Duration::from_millis(100)); } Ok(ShardStatus::Failed((rank, err))) => { - tracing::error!("Shard {} failed to start:\n{}", rank, err); + tracing::error!("Shard {rank} failed to start"); + if let Some(err) = err { + tracing::error!("{err}"); + } shutdown_shards(shutdown, shutdown_receiver); return Err(LauncherError::ShardCannotStart); } @@ -835,7 +848,7 @@ fn spawn_shards( fn spawn_webserver( args: Args, - shutdown: Arc>, + shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, ) -> Result { // All shard started @@ -1002,7 +1015,7 @@ fn main() -> Result<(), LauncherError> { download_convert_model(&args, running.clone())?; // Shared shutdown bool - let shutdown = Arc::new(Mutex::new(false)); + let shutdown = Arc::new(AtomicBool::new(false)); // Shared shutdown channel // When shutting down, the main thread will wait for all senders to be dropped let (shutdown_sender, shutdown_receiver) = mpsc::channel(); @@ -1034,7 +1047,10 @@ fn main() -> Result<(), LauncherError> { while running.load(Ordering::SeqCst) { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {rank} failed:\n{err}"); + tracing::error!("Shard {rank} failed to start"); + if let Some(err) = err { + tracing::error!("{err}"); + } exit_code = Err(LauncherError::ShardFailed); break; }; From e28a809004620c3f3a1cc28d4bbc0b4775b1328f Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Sat, 1 Jul 2023 19:25:41 +0200 Subject: [PATCH 05/11] v0.9.0 (#525) --- Cargo.lock | 505 ++++++++++++++++++++------------ Cargo.toml | 2 +- Dockerfile | 2 +- README.md | 2 +- aml/README.md | 15 - aml/deployment.yaml | 38 --- aml/endpoint.yaml | 3 - aml/model.yaml | 3 - docs/openapi.json | 31 +- launcher/src/main.rs | 8 +- router/Cargo.toml | 10 +- router/client/Cargo.toml | 4 +- router/grpc-metadata/Cargo.toml | 6 +- router/src/server.rs | 1 + rust-toolchain.toml | 2 +- server/pyproject.toml | 2 +- 16 files changed, 376 insertions(+), 258 deletions(-) delete mode 100644 aml/README.md delete mode 100644 aml/deployment.yaml delete mode 100644 aml/endpoint.yaml delete mode 100644 aml/model.yaml diff --git a/Cargo.lock b/Cargo.lock index 7a6f4ad2..b65045ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "addr2line" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -10,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aes" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241" +checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2" dependencies = [ "cfg-if", "cipher", @@ -21,11 +30,11 @@ dependencies = [ [[package]] name = "ahash" -version = "0.7.6" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ - "getrandom", + "cfg-if", "once_cell", "version_check", ] @@ -65,15 +74,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d" +checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" [[package]] name = "anstyle-parse" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee" +checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" dependencies = [ "utf8parse", ] @@ -139,7 +148,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -150,7 +159,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -190,7 +199,7 @@ checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", "axum-core", - "bitflags", + "bitflags 1.3.2", "bytes", "futures-util", "http", @@ -240,11 +249,26 @@ dependencies = [ "axum", "futures", "http", - "opentelemetry", + "opentelemetry 0.18.0", "tower", "tower-http 0.3.5", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.18.0", +] + +[[package]] +name = "backtrace" +version = "0.3.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", ] [[package]] @@ -271,6 +295,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" + [[package]] name = "block-buffer" version = "0.10.4" @@ -380,9 +410,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.4" +version = "4.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"80672091db20273a15cf9fdd4e47ed43b5091ec9841bf4c6145c9dfbbcae09ed" +checksum = "384e169cc618c613d5e3ca6404dda77a8685a63e08660dcc64abaf7da7cb0c7a" dependencies = [ "clap_builder", "clap_derive", @@ -391,13 +421,12 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.3.4" +version = "4.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1458a1df40e1e2afebb7ab60ce55c1fa8f431146205aa5f4887e0b111c27636" +checksum = "ef137bbe35aab78bdb468ccfba75a5f4d8321ae011d34063770780545176af2d" dependencies = [ "anstream", "anstyle", - "bitflags", "clap_lex", "strsim", ] @@ -411,7 +440,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -528,7 +557,7 @@ version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a84cda67535339806297f1b331d6dd6320470d2a0fe65381e79ee9e156dd3d13" dependencies = [ - "bitflags", + "bitflags 1.3.2", "crossterm_winapi", "libc", "mio", @@ -609,7 +638,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" dependencies = [ "cfg-if", - "hashbrown", + "hashbrown 0.12.3", "lock_api", "once_cell", "parking_lot_core", @@ -895,7 +924,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -956,10 +985,16 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" + [[package]] name = "glob" version = "0.3.1" @@ -970,17 +1005,17 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" name = "grpc-metadata" version = "0.1.0" dependencies = [ - "opentelemetry", - "tonic", + "opentelemetry 0.19.0", + "tonic 0.9.2", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.19.0", ] [[package]] name = "h2" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d357c7ae988e7d2182f7d7871d0b963962420b0678b0997ce7de72001aeab782" +checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" dependencies = [ "bytes", "fnv", @@ -1000,6 +1035,12 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" dependencies = [ "ahash", ] @@ -1010,15 +1051,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.3.1" @@ -1087,9 +1119,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.26" +version = "0.14.27" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" dependencies = [ "bytes", "futures-channel", @@ -1157,7 +1189,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", "serde", ] @@ -1209,26 +1241,25 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", "windows-sys 0.48.0", ] [[package]] name = "ipnet" -version = "2.7.2" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" +checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "is-terminal" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" +checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb" dependencies = [ - "hermit-abi 0.3.1", - "io-lifetimes", - "rustix", + "hermit-abi", + "rustix 0.38.1", "windows-sys 0.48.0", ] @@ -1291,9 +1322,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.146" +version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "libm" @@ -1307,6 +1338,12 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +[[package]] +name = "linux-raw-sys" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" + [[package]] name = "lock_api" version = "0.4.10" @@ -1324,10 +1361,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" [[package]] -name = "mach" -version = "0.3.2" +name = "mach2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" +checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" dependencies = [ "libc", ] @@ -1386,28 +1423,27 @@ dependencies = [ [[package]] name = "metrics" -version = "0.20.1" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849" +checksum = "aa8ebbd1a9e57bbab77b9facae7f5136aea44c356943bf9a198f647da64285d6" dependencies = [ "ahash", "metrics-macros", - "portable-atomic 0.3.20", + "portable-atomic", ] [[package]] name = "metrics-exporter-prometheus" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70" +checksum = 
"8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5" dependencies = [ + "base64 0.21.2", "hyper", "indexmap", "ipnet", "metrics", "metrics-util", - "parking_lot", - "portable-atomic 0.3.20", "quanta", "thiserror", "tokio", @@ -1416,28 +1452,26 @@ dependencies = [ [[package]] name = "metrics-macros" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3" +checksum = "ddece26afd34c31585c74a4db0630c376df271c285d682d1e55012197830b6df" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.22", ] [[package]] name = "metrics-util" -version = "0.14.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a" +checksum = "111cb375987443c3de8d503580b536f77dc8416d32db62d9456db5d93bd7ac47" dependencies = [ "crossbeam-epoch", "crossbeam-utils", - "hashbrown", + "hashbrown 0.13.2", "metrics", "num_cpus", - "parking_lot", - "portable-atomic 0.3.20", "quanta", "sketches-ddsketch", ] @@ -1481,7 +1515,7 @@ checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", "log", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys 0.48.0", ] @@ -1503,7 +1537,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -1520,7 +1554,7 @@ checksum = "e92b89ac3127251efde6f5a9586e5aae99468d06fcf9f133b377f58d5ed66446" dependencies = [ "async-trait", "awaitdrop", - "bitflags", + "bitflags 1.3.2", "bytes", "futures", "pin-project", @@ -1560,9 +1594,9 @@ dependencies = [ [[package]] name = "ngrok" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98ce3514eec7338e2d4663e3efb4429e08d8f3662996be4b9585350e7d8ad728" +checksum = "87e211f407b0a084f720823a00c956aeab2c15dfe7a61760d93227bbaf048026" dependencies = [ "arc-swap", "async-rustls", @@ -1595,7 +1629,7 @@ version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "libc", "static_assertions", @@ -1648,11 +1682,11 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.2.6", + "hermit-abi", "libc", ] @@ -1668,6 +1702,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "object" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.18.0" @@ -1680,7 +1723,7 @@ version = "6.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" dependencies = [ - "bitflags", + "bitflags 1.3.2", "libc", "once_cell", "onig_sys", @@ -1698,11 +1741,11 @@ 
dependencies = [ [[package]] name = "openssl" -version = "0.10.54" +version = "0.10.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69b3f656a17a6cbc115b5c7a40c616947d213ba182135b014d6051b73ab6f019" +checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "foreign-types", "libc", @@ -1719,7 +1762,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -1730,9 +1773,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.88" +version = "0.9.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2ce0f250f34a308dcfdbb351f511359857d4ed2134ba715a4eadd46e1ffd617" +checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" dependencies = [ "cc", "libc", @@ -1746,40 +1789,49 @@ version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" dependencies = [ - "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_api 0.18.0", + "opentelemetry_sdk 0.18.0", +] + +[[package]] +name = "opentelemetry" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f" +dependencies = [ + "opentelemetry_api 0.19.0", + "opentelemetry_sdk 0.19.0", ] [[package]] name = "opentelemetry-otlp" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde" +checksum = "8af72d59a4484654ea8eb183fea5ae4eb6a41d7ac3e3bae5f4d2a282a3a7d3ca" dependencies = [ "async-trait", "futures", "futures-util", "http", - "opentelemetry", + "opentelemetry 0.19.0", "opentelemetry-proto", "prost", "thiserror", "tokio", - "tonic", + "tonic 0.8.3", ] [[package]] name = "opentelemetry-proto" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28" +checksum = "045f8eea8c0fa19f7d48e7bc3128a39c2e5c533d5c61298c548dfefc1064474c" dependencies = [ "futures", "futures-util", - "opentelemetry", + "opentelemetry 0.19.0", "prost", - "tonic", - "tonic-build", + "tonic 0.8.3", ] [[package]] @@ -1798,6 +1850,22 @@ dependencies = [ "thiserror", ] +[[package]] +name = "opentelemetry_api" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2" +dependencies = [ + "fnv", + "futures-channel", + "futures-util", + "indexmap", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", +] + [[package]] name = "opentelemetry_sdk" version = "0.18.0" @@ -1812,7 +1880,29 @@ dependencies = [ "futures-executor", "futures-util", "once_cell", - "opentelemetry_api", + "opentelemetry_api 0.18.0", + "percent-encoding", + "rand", + "thiserror", + "tokio", + "tokio-stream", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1" +dependencies = [ + "async-trait", + "crossbeam-channel", + "dashmap", + "fnv", + 
"futures-channel", + "futures-executor", + "futures-util", + "once_cell", + "opentelemetry_api 0.19.0", "percent-encoding", "rand", "thiserror", @@ -1857,7 +1947,7 @@ dependencies = [ "libc", "redox_syscall 0.3.5", "smallvec", - "windows-targets 0.48.0", + "windows-targets 0.48.1", ] [[package]] @@ -1907,22 +1997,22 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" +checksum = "6e138fdd8263907a2b0e1b4e80b7e58c721126479b6e6eedfb1b402acea7b9bd" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" +checksum = "d1fef411b303e3e12d534fb6e7852de82da56edd937d895125821fb7c09436c7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -1943,15 +2033,6 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" -[[package]] -name = "portable-atomic" -version = "0.3.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e30165d31df606f5726b090ec7592c308a0eaf61721ff64c9a3018e344a8753e" -dependencies = [ - "portable-atomic 1.3.3", -] - [[package]] name = "portable-atomic" version = "1.3.3" @@ -2000,9 +2081,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.60" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" +checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" dependencies = [ "unicode-ident", ] @@ -2063,25 +2144,25 @@ dependencies = [ [[package]] name = "quanta" -version = "0.10.1" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" dependencies = [ "crossbeam-utils", "libc", - "mach", + "mach2", "once_cell", "raw-cpuid", - "wasi 0.10.2+wasi-snapshot-preview1", + "wasi", "web-sys", "winapi", ] [[package]] name = "quote" -version = "1.0.28" +version = "1.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" dependencies = [ "proc-macro2", ] @@ -2122,7 +2203,7 @@ version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcc0d032bccba900ee32151ec0265667535c230169f5a011154cdcd984e16829" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cassowary", "crossterm", "unicode-segmentation", @@ -2135,7 +2216,7 @@ version = "10.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -2177,7 +2258,7 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -2186,7 +2267,7 @@ 
version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -2286,9 +2367,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "6.7.0" +version = "6.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b73e721f488c353141288f223b599b4ae9303ecf3e62923f40a492f0634a4dc3" +checksum = "a36224c3276f8c4ebc8c20f158eca7ca4359c8db89991c4925132aaaf6702661" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -2297,28 +2378,34 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "6.6.0" +version = "6.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22ce362f5561923889196595504317a4372b84210e6e335da529a65ea5452b5" +checksum = "49b94b81e5b2c284684141a2fb9e2a31be90638caf040bf9afbc5a0416afe1ac" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", "shellexpand", - "syn 2.0.18", + "syn 2.0.22", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "7.5.0" +version = "7.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512b0ab6853f7e14e3c8754acb43d6f748bb9ced66aa5915a6553ac8213f7731" +checksum = "9d38ff6bf570dc3bb7100fce9f7b60c33fa71d80e88da3f2580df4ff2bdded74" dependencies = [ "sha2", "walkdir", ] +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustc_version" version = "0.4.0" @@ -2330,15 +2417,28 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.20" +version = "0.37.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0" +checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", - "linux-raw-sys", + "linux-raw-sys 0.3.8", + "windows-sys 0.48.0", +] + +[[package]] +name = "rustix" +version = "0.38.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3" +dependencies = [ + "bitflags 2.3.3", + "errno", + "libc", + "linux-raw-sys 0.4.3", "windows-sys 0.48.0", ] @@ -2356,9 +2456,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ "base64 0.21.2", ] @@ -2415,7 +2515,7 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -2455,14 +2555,14 @@ checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] name = "serde_json" -version = "1.0.97" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a" +checksum = 
"46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" dependencies = [ "itoa", "ryu", @@ -2668,9 +2768,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.18" +version = "2.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" +checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" dependencies = [ "proc-macro2", "quote", @@ -2685,9 +2785,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sysinfo" -version = "0.29.2" +version = "0.29.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9557d0845b86eea8182f7b10dff120214fb6cd9fd937b6f4917714e546a38695" +checksum = "5bcd0346f90b6bc83526c7b180039a8acd26a5c848cc556d457f6472eb148122" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2742,13 +2842,13 @@ dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix", + "rustix 0.37.21", "windows-sys 0.48.0", ] [[package]] name = "text-generation-benchmark" -version = "0.8.2" +version = "0.9.0" dependencies = [ "average", "clap", @@ -2768,7 +2868,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "0.8.2" +version = "0.9.0" dependencies = [ "futures", "grpc-metadata", @@ -2776,7 +2876,7 @@ dependencies = [ "prost-build", "thiserror", "tokio", - "tonic", + "tonic 0.9.2", "tonic-build", "tower", "tracing", @@ -2784,7 +2884,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.8.2" +version = "0.9.0" dependencies = [ "clap", "ctrlc", @@ -2800,7 +2900,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "0.8.2" +version = "0.9.0" dependencies = [ "async-stream", "axum", @@ -2812,7 +2912,7 @@ dependencies = [ "metrics-exporter-prometheus", "ngrok", "nohash-hasher", - "opentelemetry", + "opentelemetry 0.19.0", "opentelemetry-otlp", "rand", "reqwest", @@ -2822,9 +2922,9 @@ dependencies = [ "thiserror", "tokenizers", "tokio", - "tower-http 0.4.0", + "tower-http 0.4.1", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.19.0", "tracing-subscriber", "utoipa", "utoipa-swagger-ui", @@ -2848,7 +2948,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -2941,11 +3041,12 @@ dependencies = [ [[package]] name = "tokio" -version = "1.28.2" +version = "1.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" dependencies = [ "autocfg", + "backtrace", "bytes", "libc", "mio", @@ -2976,7 +3077,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -3059,10 +3160,38 @@ dependencies = [ ] [[package]] -name = "tonic-build" -version = "0.8.4" +name = "tonic" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4" +checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" +dependencies = [ + "async-trait", + "axum", + "base64 0.21.2", + "bytes", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + 
"pin-project", + "prost", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07" dependencies = [ "prettyplease", "proc-macro2", @@ -3097,7 +3226,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" dependencies = [ - "bitflags", + "bitflags 1.3.2", "bytes", "futures-core", "futures-util", @@ -3112,11 +3241,11 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658" +checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c" dependencies = [ - "bitflags", + "bitflags 2.3.3", "bytes", "futures-core", "futures-util", @@ -3155,13 +3284,13 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" +checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -3202,7 +3331,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" dependencies = [ "once_cell", - "opentelemetry", + "opentelemetry 0.18.0", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00a39dcf9bfc1742fa4d6215253b33a6e474be78275884c216fc2a06267b3600" +dependencies = [ + "once_cell", + "opentelemetry 0.19.0", "tracing", "tracing-core", "tracing-log", @@ -3326,6 +3469,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" + [[package]] name = "utf8parse" version = "0.2.1" @@ -3353,7 +3502,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", ] [[package]] @@ -3422,12 +3571,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3455,7 +3598,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", "wasm-bindgen-shared", ] @@ -3489,7 +3632,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.22", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3592,7 +3735,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.48.1", ] [[package]] @@ -3612,9 +3755,9 
@@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.48.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f" dependencies = [ "windows_aarch64_gnullvm 0.48.0", "windows_aarch64_msvc 0.48.0", diff --git a/Cargo.toml b/Cargo.toml index b28286fa..ba7a920d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.8.2" +version = "0.9.0" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/Dockerfile b/Dockerfile index 1a969383..66e0091d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/README.md b/README.md index b74d2617..d31c176b 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ model=bigscience/bloom-560m num_shard=2 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.8 --model-id $model --num-shard $num_shard +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9 --model-id $model --num-shard $num_shard ``` **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. 
diff --git a/aml/README.md b/aml/README.md deleted file mode 100644 index 8e78b0ab..00000000 --- a/aml/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# Azure ML endpoint - -## Create all resources - -```shell -az ml model create -f model.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace -az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace -az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace -``` - -## Update deployment - -```shell -az ml online-deployment update -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace -``` \ No newline at end of file diff --git a/aml/deployment.yaml b/aml/deployment.yaml deleted file mode 100644 index 320eba24..00000000 --- a/aml/deployment.yaml +++ /dev/null @@ -1,38 +0,0 @@ -$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json -name: bloom-deployment -endpoint_name: bloom-inference -model: azureml:bloom-safetensors:1 -model_mount_path: /var/azureml-model -environment_variables: - WEIGHTS_CACHE_OVERRIDE: /var/azureml-model/bloom-safetensors - MODEL_ID: bigscience/bloom - NUM_SHARD: 8 -environment: - image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2.0 - inference_config: - liveness_route: - port: 80 - path: /health - readiness_route: - port: 80 - path: /health - scoring_route: - port: 80 - path: /generate -instance_type: Standard_ND96amsr_A100_v4 -request_settings: - request_timeout_ms: 90000 - max_concurrent_requests_per_instance: 256 -liveness_probe: - initial_delay: 600 - timeout: 90 - period: 120 - success_threshold: 1 - failure_threshold: 5 -readiness_probe: - initial_delay: 600 - timeout: 90 - period: 120 - success_threshold: 1 - failure_threshold: 5 -instance_count: 1 diff --git a/aml/endpoint.yaml b/aml/endpoint.yaml deleted file mode 100644 index f2f01d5e..00000000 --- a/aml/endpoint.yaml +++ /dev/null @@ -1,3 +0,0 @@ -$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json -name: bloom-inference -auth_mode: key diff --git a/aml/model.yaml b/aml/model.yaml deleted file mode 100644 index bfcdd33f..00000000 --- a/aml/model.yaml +++ /dev/null @@ -1,3 +0,0 @@ -$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json -name: bloom-safetensors -path: /data/bloom-safetensors diff --git a/docs/openapi.json b/docs/openapi.json index 8c652946..b91729d0 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "0.8.2" + "version": "0.9.0" }, "paths": { "/": { @@ -270,6 +270,35 @@ } } }, + "/health": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Health check method", + "description": "Health check method", + "operationId": "health", + "responses": { + "200": { + "description": "Everything is working fine" + }, + "503": { + "description": "Text generation inference is down", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "unhealthy", + "error_type": "healthcheck" + } + } + } + } + } + } + }, "/info": { "get": { "tags": [ diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 30abe88a..5b5cb45e 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1040,14 +1040,18 @@ fn main() -> Result<(), LauncherError> { return Ok(()); } - let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?; + let mut webserver = 
+ spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { + shutdown_shards(shutdown.clone(), &shutdown_receiver); + err + })?; // Default exit code let mut exit_code = Ok(()); while running.load(Ordering::SeqCst) { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {rank} failed to start"); + tracing::error!("Shard {rank} crashed"); if let Some(err) = err { tracing::error!("{err}"); } diff --git a/router/Cargo.toml b/router/Cargo.toml index c1e665b1..10396826 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,11 +22,11 @@ text-generation-client = { path = "client" } clap = { version = "4.1.4", features = ["derive", "env"] } flume = "0.10.14" futures = "0.3.26" -metrics = "0.20.1" -metrics-exporter-prometheus = { version = "0.11.0", features = [] } +metrics = "0.21.0" +metrics-exporter-prometheus = { version = "0.12.1", features = [] } nohash-hasher = "0.2.0" -opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } -opentelemetry-otlp = "0.11.0" +opentelemetry = { version = "0.19.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.12.0" rand = "0.8.5" reqwest = { version = "0.11.14", features = [] } serde = "1.0.152" @@ -36,7 +36,7 @@ tokenizers = "0.13.3" tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tower-http = { version = "0.4.0", features = ["cors"] } tracing = "0.1.37" -tracing-opentelemetry = "0.18.0" +tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index db7245e0..43f444e6 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -11,10 +11,10 @@ grpc-metadata = { path = "../grpc-metadata" } prost = "^0.11" thiserror = "^1.0" tokio = { version = "^1.25", features = ["sync"] } -tonic = "^0.8" +tonic = "^0.9" tower = "^0.4" tracing = "^0.1" [build-dependencies] -tonic-build = "0.8.4" +tonic-build = "0.9.2" prost-build = "0.11.6" diff --git a/router/grpc-metadata/Cargo.toml b/router/grpc-metadata/Cargo.toml index 311092e3..9e01f527 100644 --- a/router/grpc-metadata/Cargo.toml +++ b/router/grpc-metadata/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -opentelemetry = "0.18.0" -tonic = "^0.8" +opentelemetry = "^0.19" +tonic = "^0.9" tracing = "^0.1" -tracing-opentelemetry = "0.18.0" +tracing-opentelemetry = "^0.19" diff --git a/router/src/server.rs b/router/src/server.rs index ee96ead6..54418f84 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -532,6 +532,7 @@ pub async fn run( #[derive(OpenApi)] #[openapi( paths( + health, get_model_info, compat_generate, generate, diff --git a/rust-toolchain.toml b/rust-toolchain.toml index d6ecf0c4..2db1883c 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.69.0" +channel = "1.70.0" components = ["rustfmt", "clippy"] \ No newline at end of file diff --git a/server/pyproject.toml b/server/pyproject.toml index f0ec25eb..294bcfc0 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "0.8.2" +version = "0.9.0" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] From 1da07e85aae8ce417dda3effd516691394dc31a1 Mon Sep 17 00:00:00 2001 From: 
Nicolas Patry Date: Mon, 3 Jul 2023 13:01:46 +0200 Subject: [PATCH 06/11] feat(server): Add Non flash MPT. (#514) # What does this PR do? This adds a non flash version of MPT. Flash is harder because we need to create a bias ready cuda kernel of flash attention. Fixes https://github.com/huggingface/text-generation-inference/issues/361 Fixes https://github.com/huggingface/text-generation-inference/issues/491 Fixes https://github.com/huggingface/text-generation-inference/issues/290 --- .../__snapshots__/test_mpt/test_mpt.json | 140 ++ .../__snapshots__/test_mpt/test_mpt_load.json | 562 ++++++++ integration-tests/models/test_mpt.py | 48 + server/poetry.lock | 13 +- server/pyproject.toml | 1 + server/requirements.txt | 1 + .../text_generation_server/models/__init__.py | 5 + .../models/custom_modeling/mpt_modeling.py | 1140 +++++++++++++++++ server/text_generation_server/models/mpt.py | 90 ++ server/text_generation_server/utils/layers.py | 12 + 10 files changed, 2011 insertions(+), 1 deletion(-) create mode 100644 integration-tests/models/__snapshots__/test_mpt/test_mpt.json create mode 100644 integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json create mode 100644 integration-tests/models/test_mpt.py create mode 100644 server/text_generation_server/models/custom_modeling/mpt_modeling.py create mode 100644 server/text_generation_server/models/mpt.py diff --git a/integration-tests/models/__snapshots__/test_mpt/test_mpt.json b/integration-tests/models/__snapshots__/test_mpt/test_mpt.json new file mode 100644 index 00000000..abbbf03c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mpt/test_mpt.json @@ -0,0 +1,140 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5117188, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.96875, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.953125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.94189453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5830078, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3105469, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.3215332, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5566406, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.6074219, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.6923828, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5263672, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.8544922, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6118164, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.055877686, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0537109, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.0115737915, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9111328, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4589844, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.4853516, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021636963, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" +} diff --git a/integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json b/integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json new file mode 100644 index 00000000..e3bc57ed --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json @@ -0,0 +1,562 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5117188, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.96875, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.953125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.94189453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5830078, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3183594, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.32617188, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.6015625, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.67822266, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5395508, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.8623047, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6020508, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.0552063, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0742188, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011405945, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9165039, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4501953, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.4960938, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.02116394, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.984375, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.96875, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.93359375, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5800781, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3242188, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.31835938, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5644531, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.5957031, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.68603516, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5258789, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.859375, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6166992, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.056762695, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0703125, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011428833, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9213867, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4726562, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.5039062, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021652222, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.984375, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.96875, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.93359375, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5800781, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3242188, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.31835938, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5644531, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.5957031, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.68603516, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5258789, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.859375, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6166992, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.056762695, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0703125, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011428833, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9213867, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4726562, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.5039062, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021652222, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.984375, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.96875, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.93359375, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5800781, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3242188, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.31835938, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5644531, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.5957031, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.68603516, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5258789, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.859375, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6166992, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.056762695, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0703125, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011428833, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9213867, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4726562, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.5039062, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021652222, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + } +] diff --git a/integration-tests/models/test_mpt.py b/integration-tests/models/test_mpt.py new file mode 100644 index 00000000..d58a8c5a --- /dev/null +++ b/integration-tests/models/test_mpt.py @@ -0,0 +1,48 @@ +import pytest + + +@pytest.fixture(scope="module") +def mpt_sharded_handle(launcher): + with launcher("mosaicml/mpt-7b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def mpt_sharded(mpt_sharded_handle): + await mpt_sharded_handle.health(300) + return mpt_sharded_handle.client + + +@pytest.mark.asyncio +async def test_mpt(mpt_sharded, response_snapshot): + response = await mpt_sharded.generate( + "What is Deep Learning?", + max_new_tokens=17, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 17 + assert ( + response.generated_text + == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_mpt_load(mpt_sharded, generate_load, response_snapshot): + responses = await generate_load( + mpt_sharded, + "What is Deep Learning?", + max_new_tokens=17, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert ( + responses[0].generated_text + == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + ) + + assert responses == response_snapshot diff --git a/server/poetry.lock b/server/poetry.lock index 9a6900bc..7c126e80 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -187,6 +187,17 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "einops" +version = "0.6.1" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.7" +files = [ + 
{file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"}, + {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"}, +] + [[package]] name = "exceptiongroup" version = "1.1.1" @@ -1586,4 +1597,4 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "54ecacb32d699cb1298c237c4661c1b707f119cf2c27bd54bad7a1ea2ffb8b10" +content-hash = "3174a211d30bed5990ed5f8418416c951bb6c585153fb51b62809baa89ef07d0" diff --git a/server/pyproject.toml b/server/pyproject.toml index 294bcfc0..bbf5836d 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -27,6 +27,7 @@ sentencepiece = "^0.1.97" tokenizers = "0.13.3" huggingface-hub = "^0.14.1" transformers = "^4.29.2" +einops = "^0.6.1" [tool.poetry.extras] accelerate = ["accelerate"] diff --git a/server/requirements.txt b/server/requirements.txt index a9bd441c..92693bbd 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -4,6 +4,7 @@ charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0" click==8.1.3 ; python_version >= "3.9" and python_version < "4.0" colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" +einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0" filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0" googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0" diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e45e198a..fd97f8b1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -10,6 +10,7 @@ from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOMSharded +from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.rw import RW from text_generation_server.models.opt import OPTSharded @@ -178,6 +179,10 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == "mpt": + return MPTSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) elif model_type == "gpt_neox": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py new file mode 100644 index 00000000..5ea204e1 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -0,0 +1,1140 @@ +"""A simple, flexible implementation of a GPT model. 
+ +Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py +""" +import math +import os +import warnings +from typing import List, Optional, Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from einops import rearrange +from packaging import version +from text_generation_server.utils.layers import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelHead, + get_linear, +) + +EPS = 1e-5 + + +def load_col(config, prefix, weights, bias): + assert bias == False, NotImplementedError + assert config.quantize != "gptq", NotImplementedError + slice_ = weights._get_slice(f"{prefix}.weight") + rank = weights.process_group.rank() + size = weights.process_group.size() + + h3, h = slice_.get_shape() + block_size = h // size + + q_part = slice_[rank * block_size : (rank + 1) * block_size] + k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] + v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] + + weight = torch.cat([q_part, k_part, v_part], dim=0) + if weight.dtype != torch.int32: + weight = weight.to(dtype=weights.dtype) + weight = weight.to(device=weights.device) + bias = None + linear = get_linear(weight, bias, config.quantize) + return TensorParallelColumnLinear(linear) + + +def _reset_is_causal( + num_query_tokens: int, num_key_tokens: int, original_is_causal: bool +): + if original_is_causal and num_query_tokens != num_key_tokens: + if num_query_tokens != 1: + raise NotImplementedError( + "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." + ) + else: + return False + return original_is_causal + + +def scaled_multihead_dot_product_attention( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) + kv_n_heads = 1 if multiquery else n_heads + k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) + v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) + if past_key_value is not None: + if len(past_key_value) != 0: + k = torch.cat([past_key_value[0], k], dim=3) + v = torch.cat([past_key_value[1], v], dim=2) + past_key_value = (k, v) + (b, _, s_q, d) = q.shape + s_k = k.size(-1) + attn_weight = q.matmul(k) * softmax_scale + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - s_q) + _s_k = max(0, attn_bias.size(3) - s_k) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if ( + attn_bias.size(-1) != 1 + and attn_bias.size(-1) != s_k + or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) + ): + raise RuntimeError( + f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." + ) + attn_weight = attn_weight + attn_bias + min_val = torch.finfo(q.dtype).min + if key_padding_mask is not None: + if attn_bias is not None: + warnings.warn( + "Propogating key_padding_mask to the attention module " + + "and applying it within the attention module can cause " + + "unneccessary computation/memory usage. Consider integrating " + + "into attn_bias once and passing that to each attention " + + "module instead." 
+ ) + attn_weight = attn_weight.masked_fill( + ~key_padding_mask.view((b, 1, 1, s_k)), min_val + ) + if is_causal and (not q.size(2) == 1): + s = max(s_q, s_k) + causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) + causal_mask = causal_mask.tril() + causal_mask = causal_mask.to(torch.bool) + causal_mask = ~causal_mask + causal_mask = causal_mask[-s_q:, -s_k:] + attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p: + attn_weight = torch.nn.functional.dropout( + attn_weight, p=dropout_p, training=training, inplace=True + ) + out = attn_weight.to(v.dtype).matmul(v) + out = rearrange(out, "b h s d -> b s (h d)") + if needs_weights: + return (out, attn_weight, past_key_value) + return (out, None, past_key_value) + + +def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): + for tensor in tensors: + if tensor.dtype not in valid_dtypes: + raise TypeError( + f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." + ) + if not tensor.is_cuda: + raise TypeError( + f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." + ) + + +def flash_attn_fn( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + try: + from flash_attn import bert_padding, flash_attn_interface + except: + raise RuntimeError("Please install flash-attn==1.0.3.post0") + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - query.size(1)) + _s_k = max(0, attn_bias.size(3) - key.size(1)) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if attn_bias is not None: + raise NotImplementedError(f"attn_bias not implemented for flash attn.") + (batch_size, seqlen) = query.shape[:2] + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1) :] + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + query, query_padding_mask + ) + query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + key, key_padding_mask + ) + key_unpad = rearrange( + key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads + ) + (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) + value_unpad = rearrange( + value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads + ) + if multiquery: + key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) + value_unpad = value_unpad.expand( + value_unpad.size(0), n_heads, value_unpad.size(-1) + ) + dropout_p = dropout_p if training else 0.0 + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + output_unpad = flash_attn_interface.flash_attn_unpadded_func( + query_unpad, + key_unpad, + value_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale=softmax_scale, + causal=reset_is_causal, + return_attn_probs=needs_weights, + ) + output = bert_padding.pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, 
batch_size, seqlen + ) + return (output, None, past_key_value) + + +def triton_flash_attn_fn( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + try: + from .flash_attn_triton import flash_attn_func + except: + _installed = False + if version.parse(torch.__version__) < version.parse("2.0.0"): + _installed = True + try: + from flash_attn.flash_attn_triton import flash_attn_func + except: + _installed = False + if not _installed: + raise RuntimeError( + "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." + ) + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - query.size(1)) + _s_k = max(0, attn_bias.size(3) - key.size(1)) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if dropout_p: + raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") + if needs_weights: + raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") + if key_padding_mask is not None: + warnings.warn( + "Propagating key_padding_mask to the attention module " + + "and applying it within the attention module can cause " + + "unnecessary computation/memory usage. Consider integrating " + + "into attn_bias once and passing that to each attention " + + "module instead." + ) + (b_size, s_k) = key_padding_mask.shape[:2] + if attn_bias is None: + attn_bias = query.new_zeros(b_size, 1, 1, s_k) + attn_bias = attn_bias.masked_fill( + ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min + ) + query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) + key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) + value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) + if multiquery: + key = key.expand(*key.shape[:2], n_heads, key.size(-1)) + value = value.expand(*value.shape[:2], n_heads, value.size(-1)) + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + attn_output = flash_attn_func( + query, key, value, attn_bias, reset_is_causal, softmax_scale + ) + output = attn_output.view(*attn_output.shape[:2], -1) + return (output, None, past_key_value) + + +class MultiheadAttention(nn.Module): + """Multi-head self attention. + + Using torch or triton attention implemetation enables user to also use + additive bias. 
+ """ + + def __init__( + self, + config, + prefix, + weights, + ): + super().__init__() + attn_impl = config.attn_config["attn_impl"] + self.attn_impl = config.attn_config["attn_impl"] + self.clip_qkv = config.attn_config["clip_qkv"] + self.qk_ln = config.attn_config["qk_ln"] + self.d_model = config.d_model + d_model = config.d_model + self.n_heads = config.n_heads + self.softmax_scale = config.attn_config["softmax_scale"] + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) + self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.n_heads = self.n_heads // weights.process_group.size() + self.Wqkv = load_col( + config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias + ) + if self.qk_ln: + raise NotImplementedError("qk_ln is not supported") + if self.attn_impl == "flash": + self.attn_fn = flash_attn_fn + elif self.attn_impl == "triton": + self.attn_fn = triton_flash_attn_fn + elif self.attn_impl == "torch": + self.attn_fn = scaled_multihead_dot_product_attention + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=not config.no_bias, + ) + + def forward( + self, + x, + past_key_value=None, + attn_bias=None, + attention_mask=None, + is_causal=True, + needs_weights=False, + ): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.chunk(3, dim=2) + + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + (context, attn_weights, past_key_value) = self.attn_fn( + query, + key, + value, + self.n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + ) + out = self.out_proj(context) + return (out, attn_weights, past_key_value) + + +class MultiQueryAttention(nn.Module): + """Multi-Query self attention. + + Using torch or triton attention implemetation enables user to also use + additive bias. + """ + + def __init__(self, config, prefix, weights): + super().__init__() + attn_impl = config.attn_config["attn_impl"] + self.attn_impl = config.attn_config["attn_impl"] + self.clip_qkv = config.attn_config["clip_qkv"] + self.qk_ln = config.attn_config["qk_ln"] + self.d_model = config.d_model + d_model = config.d_model + self.n_heads = config.n_heads + self.softmax_scale = config.attn_config["softmax_scale"] + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.head_dim) + self.attn_dropout_p = config.attn_config["attn_pdrop"] + # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) + self.Wqkv = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias + ) + fuse_splits = (d_model, d_model + self.head_dim) + if self.qk_ln: + raise NotImplementedError("qk_ln not supported") + if self.attn_impl == "flash": + self.attn_fn = flash_attn_fn + elif self.attn_impl == "triton": + self.attn_fn = triton_flash_attn_fn + if verbose: + warnings.warn( + "While `attn_impl: triton` can be faster than `attn_impl: flash` " + + "it uses more memory. When training larger models this can trigger " + + "alloc retries which hurts performance. 
If encountered, we recommend " + + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." + ) + elif self.attn_impl == "torch": + self.attn_fn = scaled_multihead_dot_product_attention + if torch.cuda.is_available() and verbose: + warnings.warn( + "Using `attn_impl: torch`. If your model does not use `alibi` or " + + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + + "we recommend using `attn_impl: triton`." + ) + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=not config.no_bias, + ) + # self.out_proj._is_residual = True + + def forward( + self, + x, + past_key_value=None, + attn_bias=None, + attention_mask=None, + is_causal=True, + needs_weights=False, + ): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.split( + [self.d_model, self.head_dim, self.head_dim], dim=2 + ) + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + (context, attn_weights, past_key_value) = self.attn_fn( + query, + key, + value, + self.n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + multiquery=True, + ) + return (self.out_proj(context), attn_weights, past_key_value) + + +def attn_bias_shape( + attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id +): + if attn_impl == "flash": + return None + elif attn_impl in ["torch", "triton"]: + if alibi: + if (prefix_lm or not causal) or use_sequence_id: + return (1, n_heads, seq_len, seq_len) + return (1, n_heads, 1, seq_len) + elif prefix_lm or use_sequence_id: + return (1, 1, seq_len, seq_len) + return None + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + + +def build_attn_bias( + attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 +): + if attn_impl == "flash": + return None + elif attn_impl in ["torch", "triton"]: + if alibi: + (device, dtype) = (attn_bias.device, attn_bias.dtype) + attn_bias = attn_bias.add( + build_alibi_bias( + n_heads, + seq_len, + full=not causal, + alibi_bias_max=alibi_bias_max, + device=device, + dtype=dtype, + ) + ) + return attn_bias + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + + +def gen_slopes(n_heads, alibi_bias_max=8, device=None): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) + m = m.mul(alibi_bias_max / _n_heads) + slopes = 1.0 / torch.pow(2, m) + if _n_heads != n_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] + return slopes.view(1, n_heads, 1, 1) + + +def build_alibi_bias( + n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None +): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( + 1, 1, 1, seq_len + ) + if full: + alibi_bias = alibi_bias - torch.arange( + 1 - seq_len, 1, dtype=torch.int32, device=device + ).view(1, 1, seq_len, 1) + alibi_bias = alibi_bias.abs().mul(-1) + slopes = gen_slopes(n_heads, alibi_bias_max, device=device) + alibi_bias = alibi_bias * slopes + return alibi_bias.to(dtype=dtype) + + +ATTN_CLASS_REGISTRY = { + 
"multihead_attention": MultiheadAttention, + "multiquery_attention": MultiQueryAttention, +} + +"""GPT Blocks used for the GPT Model.""" + + +class MPTMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) + self.up_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias + ) + self.act = nn.GELU(approximate="none") + # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=not config.no_bias, + ) + # self.down_proj._is_residual = True + + def forward(self, x): + return self.down_proj(self.act(self.up_proj(x))) + + +class MPTBlock(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.prefix = prefix + if config.attn_config["attn_type"] != "multihead_attention": + raise NotImplementedError( + f"""Not implemented attn {config.attn_config["attn_type"]}""" + ) + resid_pdrop = config.resid_pdrop + self.norm_1 = nn.LayerNorm.load_no_bias( + prefix=f"{prefix}.norm_1", weights=weights, eps=EPS + ) + self.norm_2 = nn.LayerNorm.load_no_bias( + prefix=f"{prefix}.norm_2", weights=weights, eps=EPS + ) + self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) + self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) + self.resid_ffn_dropout = nn.Dropout(resid_pdrop) + + def forward( + self, + x: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_bias: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.ByteTensor] = None, + is_causal: bool = True, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + a = self.norm_1(x) + (b, attn_weights, past_key_value) = self.attn( + a, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=is_causal, + ) + x = x + self.resid_attn_dropout(b) + m = self.norm_2(x) + n = self.ffn(m) + x = x + self.resid_ffn_dropout(n) + return (x, attn_weights, past_key_value) + + +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + + +class LPLayerNorm(torch.nn.LayerNorm): + def __init__( + self, + normalized_shape, + eps=1e-05, + elementwise_affine=True, + device=None, + dtype=None, + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + device=device, + dtype=dtype, + ) + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + downcast_bias = ( + _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + ) + with torch.autocast(enabled=False, device_type=module_device.type): + return torch.nn.functional.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) + + +def rms_norm(x, weight=None, eps=1e-05): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + if weight is not None: + return output * weight + return output + + 
+class RMSNorm(torch.nn.Module): + def __init__( + self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None + ): + super().__init__() + self.eps = eps + if weight: + self.weight = torch.nn.Parameter( + torch.ones(normalized_shape, dtype=dtype, device=device) + ) + else: + self.register_parameter("weight", None) + + def forward(self, x): + return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) + + +class LPRMSNorm(RMSNorm): + def __init__( + self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + weight=weight, + dtype=dtype, + device=device, + ) + + def forward(self, x): + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + with torch.autocast(enabled=False, device_type=x.device.type): + return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) + + +NORM_CLASS_REGISTRY = { + "layernorm": torch.nn.LayerNorm, + "low_precision_layernorm": LPLayerNorm, + "rmsnorm": RMSNorm, + "low_precision_rmsnorm": LPRMSNorm, +} + +Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +class MPTPreTrainedModel(PreTrainedModel): + base_model_prefix = "model" + _no_split_modules = ["MPTBlock"] + + +class MPTModel(MPTPreTrainedModel): + def __init__(self, config, weights): + # config._validate_config() + super().__init__(config) + self.world_size = weights.process_group.size() + self.rank = weights.process_group.rank() + self.n_heads = config.n_heads + self.attn_impl = config.attn_config["attn_impl"] + self.prefix_lm = config.attn_config["prefix_lm"] + self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] + self.alibi = config.attn_config["alibi"] + self.alibi_bias_max = config.attn_config["alibi_bias_max"] + if config.init_device == "mixed": + if dist.get_local_rank() == 0: + config.init_device = "cpu" + else: + config.init_device = "meta" + if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): + norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) + raise NotImplementedError( + f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." + ) + if config.norm_type.lower() != "low_precision_layernorm": + raise NotImplementedError( + f"Requested norm type ({config.norm_type}) is not implemented within this repo." 
+ ) + + self.wte = TensorParallelEmbedding("transformer.wte", weights) + if not self.alibi: + # self.wpe = torch.nn.Embedding( + # config.max_seq_len, config.d_model, device=config.init_device + # ) + raise RuntimeError("no alibi no supported") + self.blocks = nn.ModuleList( + [ + MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) + for i in range(config.n_layers) + ] + ) + self.norm_f = nn.LayerNorm.load_no_bias( + prefix="transformer.norm_f", weights=weights, eps=EPS + ) + self.is_causal = not self.prefix_lm + self._attn_bias_initialized = False + self.attn_bias = None + self.attn_bias_shape = attn_bias_shape( + self.attn_impl, + config.n_heads, + config.max_seq_len, + self.alibi, + prefix_lm=self.prefix_lm, + causal=self.is_causal, + use_sequence_id=self.attn_uses_sequence_id, + ) + if config.no_bias: + for module in self.modules(): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + if config.verbose: + warnings.warn(f"Removing bias ({module.bias}) from {module}.") + module.register_parameter("bias", None) + if config.verbose and config.verbose > 2: + print(self) + if "verbose" not in self.config.init_config: + self.config.init_config["verbose"] = self.config.verbose + if self.config.init_config["verbose"] > 1: + init_fn_name = self.config.init_config["name"] + warnings.warn(f"Using {init_fn_name} initialization.") + + @torch.no_grad() + def _attn_bias( + self, + device, + dtype, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + ): + if not self._attn_bias_initialized: + if self.attn_bias_shape: + self.attn_bias = torch.zeros( + self.attn_bias_shape, device=device, dtype=dtype + ) + self.attn_bias = build_attn_bias( + self.attn_impl, + self.attn_bias, + self.config.n_heads, + self.config.max_seq_len, + causal=self.is_causal, + alibi=self.alibi, + alibi_bias_max=self.alibi_bias_max, + ) + assert self.n_heads % self.world_size == 0 + block_size = self.n_heads // self.world_size + self.attn_bias = self.attn_bias[ + :, self.rank * block_size : (self.rank + 1) * block_size + ] + self._attn_bias_initialized = True + if self.attn_impl == "flash": + return (self.attn_bias, attention_mask) + if self.attn_bias is not None: + self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) + attn_bias = self.attn_bias + if self.prefix_lm: + assert isinstance(attn_bias, torch.Tensor) + assert isinstance(prefix_mask, torch.Tensor) + attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) + if self.attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) + attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + if attention_mask is not None: + s_k = attention_mask.shape[-1] + if attn_bias is None: + attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) + else: + _s_k = max(0, attn_bias.size(-1) - s_k) + attn_bias = attn_bias[:, :, :, _s_k:] + if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: + raise ValueError( + f"attention_mask shape={attention_mask.shape} " + + f"and prefix_mask shape={prefix_mask.shape} are not equal." 
+ ) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill( + ~attention_mask.view(-1, 1, 1, s_k), min_val + ) + return (attn_bias, None) + + def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): + (s_k, s_q) = attn_bias.shape[-2:] + if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: + raise ValueError( + "attn_bias does not match the expected shape. " + + f"The last two dimensions should both be {self.config.max_length} " + + f"but are {s_k} and {s_q}." + ) + seq_len = prefix_mask.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError( + f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" + ) + attn_bias = attn_bias[..., :seq_len, :seq_len] + causal = torch.tril( + torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) + ).view(1, 1, seq_len, seq_len) + prefix = prefix_mask.view(-1, 1, 1, seq_len) + cannot_attend = ~torch.logical_or(causal, prefix.bool()) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def _apply_sequence_id( + self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor + ): + seq_len = sequence_id.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError( + f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" + ) + attn_bias = attn_bias[..., :seq_len, :seq_len] + cannot_attend = torch.logical_not( + torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) + ).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if attention_mask is not None: + attention_mask = attention_mask.bool() + if prefix_mask is not None: + prefix_mask = prefix_mask.bool() + if not return_dict: + raise NotImplementedError( + "return_dict False is not implemented yet for MPT" + ) + if output_attentions: + if self.attn_impl != "torch": + raise NotImplementedError( + "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." + ) + if ( + attention_mask is not None + and attention_mask[:, 0].sum() != attention_mask.shape[0] + and self.training + ): + raise NotImplementedError( + "MPT does not support training with left padding." + ) + if self.prefix_lm and prefix_mask is None: + raise ValueError( + "prefix_mask is a required argument when MPT is configured with prefix_lm=True." + ) + if self.training: + if self.attn_uses_sequence_id and sequence_id is None: + raise ValueError( + "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + + "and the model is in train mode." + ) + elif self.attn_uses_sequence_id is False and sequence_id is not None: + warnings.warn( + "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. 
" + + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." + ) + S = input_ids.size(1) + assert ( + S <= self.config.max_seq_len + ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" + tok_emb = self.wte(input_ids) + if self.alibi: + x = tok_emb + else: + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f"past_key_values must provide a past_key_value for each attention " + + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." + ) + past_position = past_key_values[0][0].size(1) + if self.attn_impl == "torch": + past_position = past_key_values[0][0].size(3) + if S + past_position > self.config.max_seq_len: + raise ValueError( + f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." + ) + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + pos = torch.clamp( + pos + - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ + :, past_position: + ], + min=0, + ) + pos_emb = self.wpe(pos) + x = tok_emb + pos_emb + (attn_bias, attention_mask) = self._attn_bias( + device=x.device, + dtype=torch.float32, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + ) + if use_cache and past_key_values is None: + past_key_values = [() for _ in range(self.config.n_layers)] + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for b_idx, block in enumerate(self.blocks): + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + past_key_value = ( + past_key_values[b_idx] if past_key_values is not None else None + ) + (x, attn_weights, past_key_value) = block( + x, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=self.is_causal, + ) + if past_key_values is not None: + past_key_values[b_idx] = past_key_value + if output_attentions: + assert all_self_attns is not None + all_self_attns = all_self_attns + (attn_weights,) + x = self.norm_f(x) + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + return BaseModelOutputWithPast( + last_hidden_state=x, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MPTForCausalLM(MPTPreTrainedModel): + def __init__(self, config, weights): + super().__init__(config) + if not config.tie_word_embeddings: + raise ValueError("MPTForCausalLM only supports tied word embeddings") + self.transformer = MPTModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, prefix="transformer.wte", weights=weights + ) + self.logit_scale = None + if config.logit_scale is not None: + logit_scale = config.logit_scale + if isinstance(logit_scale, str): + if logit_scale == "inv_sqrt_d_model": + logit_scale = 1 / math.sqrt(config.d_model) + else: + raise ValueError( + f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." 
+ ) + self.logit_scale = logit_scale + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + outputs = self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + return_dict=return_dict, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + ) + logits = self.lm_head(outputs.last_hidden_state) + if self.logit_scale is not None: + if self.logit_scale == 0: + warnings.warn( + f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." + ) + logits *= self.logit_scale + loss = None + if labels is not None: + labels = torch.roll(labels, shifts=-1) + labels[:, -1] = -100 + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) + ) + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + if inputs_embeds is not None: + raise NotImplementedError("inputs_embeds is not implemented for MPT yet") + attention_mask = kwargs["attention_mask"].bool() + if attention_mask[:, -1].sum() != attention_mask.shape[0]: + raise NotImplementedError( + "MPT does not support generation with right padding." + ) + if self.transformer.attn_uses_sequence_id and self.training: + sequence_id = torch.zeros_like(input_ids[:1]) + else: + sequence_id = None + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if self.transformer.prefix_lm: + prefix_mask = torch.ones_like(attention_mask) + if kwargs.get("use_cache") == False: + raise NotImplementedError( + "MPT with prefix_lm=True does not support use_cache=False." + ) + else: + prefix_mask = None + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prefix_mask": prefix_mask, + "sequence_id": sequence_id, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", True), + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + """Used by HuggingFace generate when using beam search with kv-caching. + + See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 + for an example in transformers. 
+ """ + reordered_past = [] + for layer_past in past_key_values: + reordered_past += [ + tuple( + (past_state.index_select(0, beam_idx) for past_state in layer_past) + ) + ] + return reordered_past diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py new file mode 100644 index 00000000..87fc1f07 --- /dev/null +++ b/server/text_generation_server/models/mpt.py @@ -0,0 +1,90 @@ +import torch +import torch.distributed + +from typing import Optional, Type +from opentelemetry import trace +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase +from huggingface_hub import hf_hub_download +import json + +from text_generation_server.models import CausalLM +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.custom_modeling.mpt_modeling import ( + MPTForCausalLM, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class MPTCausalLMBatch(CausalLMBatch): + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "CausalLMBatch": + batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) + batch.keys_head_dim_last = False + return batch + + +class MPTSharded(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 + else: + raise NotImplementedError("MPTSharded is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.pad_token = tokenizer.eos_token + + filename = hf_hub_download(model_id, revision=revision, filename="config.json") + with open(filename, "r") as f: + config = json.load(f) + config = PretrainedConfig(**config) + config.quantize = quantize + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + + config.quantize = quantize + model = MPTForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return MPTCausalLMBatch diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index a2b0c739..cbdfea66 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -31,7 +31,19 @@ def load_layer_norm(cls, prefix, weights, eps): return ln +@classmethod +def load_layer_norm_no_bias(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = nn.Parameter(weight) + ln.bias = None + return ln + + torch.nn.LayerNorm.load = load_layer_norm +torch.nn.LayerNorm.load_no_bias = 
load_layer_norm_no_bias class FastLinear(nn.Module): From 8405581fcd9f5a3c81696c080232eb944f98ccf2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 4 Jul 2023 00:39:25 -0700 Subject: [PATCH 07/11] fix: Update server/Makefile to include Makefile-vllm (#520) # What does this PR do? For consistency and ease of use (you can just run `make` to install vllm without any extra steps). Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/server/Makefile b/server/Makefile index 17020c97..d0086928 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,5 @@ include Makefile-flash-att +include Makefile-vllm unit-tests: pytest -s -vv -m "not private" tests From e6888d0e87c21bf392b62938c75c07ec92d5fcb8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 4 Jul 2023 18:35:37 +0200 Subject: [PATCH 08/11] docs(benchmarker): Adding some help for the options in `text-generation-benchmark`. (#462) --- benchmark/src/main.rs | 77 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 6172d377..a7550060 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -14,36 +14,85 @@ use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + /// The name of the tokenizer (as in model_id on the huggingface hub, or local path). #[clap(short, long, env)] tokenizer_name: String, + + /// The revision to use for the tokenizer if on the hub. #[clap(default_value = "main", long, env)] revision: String, + + /// The various batch sizes to benchmark for, the idea is to get enough + /// batching to start seeing increased latency, this usually means you're + /// moving from memory bound (usual as BS=1) to compute bound, and this is + /// a sweet spot for the maximum batch size for the model under test #[clap(short, long)] batch_size: Option>, + + /// This is the initial prompt sent to the text-generation-server length + /// in token. Longer prompt will slow down the benchmark. Usually the + /// latency grows somewhat linearly with this for the prefill step. + /// + /// Most importantly, the prefill step is usually not the one dominating + /// your runtime, so it's ok to keep it short. #[clap(default_value = "10", short, long, env)] sequence_length: u32, + + /// This is how many tokens will be generated by the server and averaged out + /// to give the `decode` latency. 
This is the *critical* number you want to optimize for + /// LLM spend most of their time doing decoding. + /// + /// Decode latency is usually quite stable. #[clap(default_value = "8", short, long, env)] decode_length: u32, + + ///How many runs should we average from #[clap(default_value = "10", short, long, env)] runs: usize, + + /// Number of warmup cycles #[clap(default_value = "1", short, long, env)] warmups: usize, - #[clap(long, env)] - temperature: Option, - #[clap(long, env)] - top_k: Option, - #[clap(long, env)] - top_p: Option, - #[clap(long, env)] - typical_p: Option, - #[clap(long, env)] - repetition_penalty: Option, - #[clap(long, env)] - watermark: bool, - #[clap(long, env)] - do_sample: bool, + + /// The location of the grpc socket. This benchmark tool bypasses the router + /// completely and directly talks to the gRPC processes #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)] master_shard_uds_path: String, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + temperature: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_k: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_p: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + typical_p: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + repetition_penalty: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + watermark: bool, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + do_sample: bool, } fn main() -> Result<(), Box> { From 2a101207d44b903c1cc9b4d968a4b24150413942 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 4 Jul 2023 09:37:25 -0700 Subject: [PATCH 09/11] fix(server): Handle loading from local files for MPT (#534) This PR allows the MPT model to be loaded from local files. Without this change, an exception will be thrown by `hf_hub_download` function if `model_id` is a local path. 
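For context, the resolution order introduced here is roughly the following (the `resolve_config_path` helper below is only an illustrative sketch; the actual change is in the diff that follows): a `config.json` sitting next to a local `model_id` path wins, and `hf_hub_download` is only used as a fallback for hub repo ids.

```python
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download


def resolve_config_path(model_id: str, revision: Optional[str] = None) -> str:
    # When model_id points at a local checkout, read its config.json directly
    local_path = Path(model_id, "config.json")
    if local_path.exists():
        return str(local_path.resolve())
    # Otherwise treat model_id as a repo id on the Hugging Face Hub
    return hf_hub_download(model_id, revision=revision, filename="config.json")
```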
--- server/text_generation_server/models/mpt.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index 87fc1f07..3c0f8167 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -1,6 +1,7 @@ import torch import torch.distributed +from pathlib import Path from typing import Optional, Type from opentelemetry import trace from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase @@ -60,7 +61,12 @@ class MPTSharded(CausalLM): ) tokenizer.pad_token = tokenizer.eos_token - filename = hf_hub_download(model_id, revision=revision, filename="config.json") + # If model_id is a local path, load the file directly + local_path = Path(model_id, "config.json") + if local_path.exists(): + filename = str(local_path.resolve()) + else: + filename = hf_hub_download(model_id, revision=revision, filename="config.json") with open(filename, "r") as f: config = json.load(f) config = PretrainedConfig(**config) From e4b26aa10bd43c93cd236a9e3388692eb1e8a321 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 4 Jul 2023 11:11:33 -0700 Subject: [PATCH 10/11] fix(server): avoid errors for very small top_p values (#544) See https://github.com/huggingface/transformers/pull/24111 I didn't add validation to the `__init__` method since it's not done for other values/warpers. --- server/text_generation_server/utils/logits_process.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 0cbbf8b0..f424eae4 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = probs <= self.top_p_opposite - if self.min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( From 31e2253ae721ea80032283b9e85ffe51945e5a55 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 4 Jul 2023 20:23:55 +0200 Subject: [PATCH 11/11] feat(server): use latest flash attention commit (#543) @njhill FYI --- server/Makefile-flash-att | 4 +- server/poetry.lock | 1909 +++++++++-------- server/pyproject.toml | 2 +- server/requirements.txt | 19 +- .../custom_modeling/flash_llama_modeling.py | 29 +- .../custom_modeling/flash_neox_modeling.py | 32 +- .../custom_modeling/flash_rw_modeling.py | 49 +- .../flash_santacoder_modeling.py | 29 +- .../models/flash_causal_lm.py | 52 +- server/text_generation_server/models/mpt.py | 4 +- server/text_generation_server/utils/layers.py | 2 +- 11 files changed, 1067 insertions(+), 1064 deletions(-) diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index 0e67a9e4..bc1d37ef 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -1,9 +1,9 @@ -flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f +flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec flash-attention: # Clone flash attention pip install packaging - git clone 
https://github.com/OlivierDehaene/flash-attention.git + git clone https://github.com/HazyResearch/flash-attention.git build-flash-attention: flash-attention cd flash-attention && git fetch && git checkout $(flash_att_commit) diff --git a/server/poetry.lock b/server/poetry.lock index 7c126e80..7d00f223 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,15 +1,10 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. - [[package]] name = "accelerate" version = "0.19.0" description = "Accelerate" +category = "main" optional = true python-versions = ">=3.7.0" -files = [ - {file = "accelerate-0.19.0-py3-none-any.whl", hash = "sha256:2866b0bf9fff08f51e6384c95fa96725838b70f1988d1cce42e56b820d8a91dd"}, - {file = "accelerate-0.19.0.tar.gz", hash = "sha256:84920226b9e642e453ef37593ee55b956b08d8200dea4087c546c34e26157e76"}, -] [package.dependencies] numpy = ">=1.17" @@ -23,51 +18,818 @@ dev = ["black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-buil quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)", "urllib3 (<2.0.0)"] rich = ["rich"] sagemaker = ["sagemaker"] -test-dev = ["datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] -test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] -test-trackers = ["comet-ml", "tensorboard", "wandb"] +test_dev = ["datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] +test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] +test_trackers = ["comet-ml", "tensorboard", "wandb"] testing = ["datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"] [[package]] name = "backoff" version = "2.2.1" description = "Function decoration for backoff and retry" +category = "main" optional = false python-versions = ">=3.7,<4.0" -files = [ - {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, - {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, -] [[package]] name = "bitsandbytes" version = "0.38.1" description = "8-bit optimizers and matrix multiplication routines." +category = "main" optional = true python-versions = "*" -files = [ - {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"}, - {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"}, -] [[package]] name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" -files = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, -] [[package]] name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
+category = "main" optional = false python-versions = ">=3.7.0" -files = [ + +[[package]] +name = "click" +version = "8.1.3" +description = "Composable command line interface toolkit" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +category = "main" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" + +[[package]] +name = "Deprecated" +version = "1.2.14" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] + +[[package]] +name = "einops" +version = "0.6.1" +description = "A new flavour of deep learning operations" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "exceptiongroup" +version = "1.1.2" +description = "Backport of PEP 654 (exception groups)" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.12.2" +description = "A platform independent file lock." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "fsspec" +version = "2023.6.0" +description = "File-system specification" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + +[[package]] +name = "googleapis-common-protos" +version = "1.59.1" +description = "Common protobufs used in Google APIs" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + +[[package]] +name = "grpc-interceptor" +version = "0.15.2" +description = "Simplifies gRPC interceptors" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" + +[package.dependencies] +grpcio = ">=1.49.1,<2.0.0" + +[package.extras] +testing = ["protobuf (>=4.21.9)"] + +[[package]] +name = "grpcio" +version = "1.56.0" 
+description = "HTTP/2-based RPC framework" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +protobuf = ["grpcio-tools (>=1.56.0)"] + +[[package]] +name = "grpcio-reflection" +version = "1.56.0" +description = "Standard Protobuf Reflection Service for gRPC" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +grpcio = ">=1.56.0" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-status" +version = "1.56.0" +description = "Status proto mapping for gRPC" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.56.0" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-tools" +version = "1.56.0" +description = "Protobuf code generator for gRPC" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +grpcio = ">=1.56.0" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + +[[package]] +name = "hf-transfer" +version = "0.1.3" +description = "" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "huggingface-hub" +version = "0.14.1" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +category = "main" +optional = false +python-versions = ">=3.7.0" + +[package.dependencies] +filelock = "*" +fsspec = "*" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "gradio", "jedi", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile"] +torch = ["torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" +optional = false +python-versions = ">=3.5" + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "Jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." 
+category = "main" +optional = true +python-versions = ">=3.7" + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "loguru" +version = "0.6.0" +description = "Python logging made (stupidly) simple" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] + +[[package]] +name = "MarkupSafe" +version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." +category = "main" +optional = true +python-versions = ">=3.7" + +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" +optional = true +python-versions = "*" + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "networkx" +version = "3.1" +description = "Python package for creating and manipulating graphs and networks" +category = "main" +optional = true +python-versions = ">=3.8" + +[package.extras] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] + +[[package]] +name = "numpy" +version = "1.25.0" +description = "Fundamental package for array computing in Python" +category = "main" +optional = false +python-versions = ">=3.9" + +[[package]] +name = "opentelemetry-api" +version = "1.15.0" +description = "OpenTelemetry Python API" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +deprecated = ">=1.2.6" +setuptools = ">=16.0" + +[[package]] +name = "opentelemetry-exporter-otlp" +version = "1.15.0" +description = "OpenTelemetry Collector Exporters" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +opentelemetry-exporter-otlp-proto-grpc = "1.15.0" +opentelemetry-exporter-otlp-proto-http = "1.15.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-grpc" +version = "1.15.0" +description = "OpenTelemetry Collector Protobuf over gRPC Exporter" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +googleapis-common-protos = ">=1.52,<2.0" +grpcio = ">=1.0.0,<2.0.0" +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-proto = "1.15.0" +opentelemetry-sdk = ">=1.12,<2.0" + +[package.extras] +test = ["pytest-grpc"] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.15.0" +description = "OpenTelemetry Collector Protobuf over HTTP Exporter" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +backoff = {version = 
">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +googleapis-common-protos = ">=1.52,<2.0" +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-proto = "1.15.0" +opentelemetry-sdk = ">=1.12,<2.0" +requests = ">=2.7,<3.0" + +[package.extras] +test = ["responses (==0.22.0)"] + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.36b0" +description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +opentelemetry-api = ">=1.4,<2.0" +setuptools = ">=16.0" +wrapt = ">=1.0.0,<2.0.0" + +[[package]] +name = "opentelemetry-instrumentation-grpc" +version = "0.36b0" +description = "OpenTelemetry gRPC instrumentation" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-instrumentation = "0.36b0" +opentelemetry-sdk = ">=1.12,<2.0" +opentelemetry-semantic-conventions = "0.36b0" +wrapt = ">=1.0.0,<2.0.0" + +[package.extras] +instruments = ["grpcio (>=1.27,<2.0)"] +test = ["opentelemetry-instrumentation-grpc[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.36b0)", "protobuf (>=3.13,<4.0)"] + +[[package]] +name = "opentelemetry-proto" +version = "1.15.0" +description = "OpenTelemetry Python Proto" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +protobuf = ">=3.19,<5.0" + +[[package]] +name = "opentelemetry-sdk" +version = "1.15.0" +description = "OpenTelemetry Python SDK" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +opentelemetry-api = "1.15.0" +opentelemetry-semantic-conventions = "0.36b0" +setuptools = ">=16.0" +typing-extensions = ">=3.7.4" + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.36b0" +description = "OpenTelemetry Semantic Conventions" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "packaging" +version = "23.1" +description = "Core utilities for Python packages" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "pluggy" +version = "1.2.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "protobuf" +version = "4.23.3" +description = "" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "psutil" +version = "5.9.5" +description = "Cross-platform lib for process and system monitoring in Python." 
+category = "main" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + +[[package]] +name = "pytest" +version = "7.4.0" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "PyYAML" +version = "6.0" +description = "YAML parser and emitter for Python" +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "regex" +version = "2023.6.3" +description = "Alternative regular expression module, to replace re." +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "safetensors" +version = "0.3.1" +description = "Fast and Safe Tensor serialization" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +all = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] +dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] +torch = ["torch (>=1.10)"] + +[[package]] +name = "sentencepiece" +version = "0.1.99" +description = "SentencePiece python wrapper" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "setuptools" +version = "68.0.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", 
"sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +category = "main" +optional = true +python-versions = ">=3.8" + +[package.dependencies] +mpmath = ">=0.19" + +[[package]] +name = "tokenizers" +version = "0.13.3" +description = "Fast and Customizable Tokenizers" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +category = "dev" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "torch" +version = "2.0.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = true +python-versions = ">=3.8.0" + +[package.dependencies] +filelock = "*" +jinja2 = "*" +networkx = "*" +sympy = "*" +typing-extensions = "*" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] + +[[package]] +name = "tqdm" +version = "4.65.0" +description = "Fast, Extensible Progress Meter" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["py-make (>=0.1.0)", "twine", "wheel"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "transformers" +version = "4.29.2" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +category = "main" +optional = false +python-versions = ">=3.7.0" + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.14.1,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.19.0)"] +agents = ["Pillow", "accelerate (>=0.19.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] +all = ["Pillow", "accelerate (>=0.19.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "numba (<0.57.0)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.2)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", 
"numba (<0.57.0)", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.19.0)", "deepspeed (>=0.8.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.19.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.8.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf (<=3.20.2)", "psutil", "pytest", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.19.0)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.6.9)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "numba (<0.57.0)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.2)", "psutil", "pyctcdecode (>=0.4.0)", "pytest", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "numba (<0.57.0)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf (<=3.20.2)", "psutil", "pyctcdecode (>=0.4.0)", "pytest", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.19.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "numba (<0.57.0)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.2)", "psutil", "pyctcdecode (>=0.4.0)", "pytest", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0)", "rjieba", 
"rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow", "accelerate (>=0.19.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "numba (<0.57.0)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.2)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +docs_specific = ["hf-doc-builder"] +fairscale = ["fairscale (>0.3)"] +flax = ["flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8,<=0.1.4)"] +flax-speech = ["kenlm", "librosa", "numba (<0.57.0)", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune]", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] +ray = ["ray[tune]"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf (<=3.20.2)", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "numba (<0.57.0)", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf (<=3.20.2)", "psutil", "pytest", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "numba (<0.57.0)", "phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] +torch = ["accelerate (>=0.19.0)", "torch (>=1.9,!=1.12.0)"] +torch-speech = ["kenlm", "librosa", "numba (<0.57.0)", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow", "torchvision"] +torchhub = ["filelock", 
"huggingface-hub (>=0.14.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.2)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow"] + +[[package]] +name = "typer" +version = "0.6.1" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +click = ">=7.1.1,<9.0.0" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + +[[package]] +name = "typing-extensions" +version = "4.7.1" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "urllib3" +version = "2.0.3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + +[[package]] +name = "wrapt" +version = "1.15.0" +description = "Module for decorators, wrappers and monkey patching." 
+category = "main" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + +[extras] +accelerate = ["accelerate"] +bnb = ["bitsandbytes"] + +[metadata] +lock-version = "1.1" +python-versions = "^3.9" +content-hash = "65afc4bfa07da4b1427d269fa745939da3851eaede9a8478f5a4bf5949d32cc9" + +[metadata.files] +accelerate = [ + {file = "accelerate-0.19.0-py3-none-any.whl", hash = "sha256:2866b0bf9fff08f51e6384c95fa96725838b70f1988d1cce42e56b820d8a91dd"}, + {file = "accelerate-0.19.0.tar.gz", hash = "sha256:84920226b9e642e453ef37593ee55b956b08d8200dea4087c546c34e26157e76"}, +] +backoff = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] +bitsandbytes = [ + {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"}, + {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"}, +] +certifi = [ + {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, + {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, +] +charset-normalizer = [ {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"}, {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b"}, {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60"}, @@ -144,312 +906,145 @@ files = [ {file = "charset_normalizer-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b"}, {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, ] - -[[package]] -name = "click" -version = "8.1.3" -description = "Composable command line interface toolkit" -optional = false -python-versions = ">=3.7" -files = [ +click = [ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[[package]] -name = "colorama" -version = "0.4.6" -description = "Cross-platform colored terminal text." -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -files = [ +colorama = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] - -[[package]] -name = "deprecated" -version = "1.2.14" -description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
-optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ +Deprecated = [ {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"}, ] - -[package.dependencies] -wrapt = ">=1.10,<2" - -[package.extras] -dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] - -[[package]] -name = "einops" -version = "0.6.1" -description = "A new flavour of deep learning operations" -optional = false -python-versions = ">=3.7" -files = [ +einops = [ {file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"}, {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"}, ] - -[[package]] -name = "exceptiongroup" -version = "1.1.1" -description = "Backport of PEP 654 (exception groups)" -optional = false -python-versions = ">=3.7" -files = [ - {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"}, - {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"}, +exceptiongroup = [ + {file = "exceptiongroup-1.1.2-py3-none-any.whl", hash = "sha256:e346e69d186172ca7cf029c8c1d16235aa0e04035e5750b4b95039e65204328f"}, + {file = "exceptiongroup-1.1.2.tar.gz", hash = "sha256:12c3e887d6485d16943a309616de20ae5582633e0a2eda17f4e10fd61c1e8af5"}, ] - -[package.extras] -test = ["pytest (>=6)"] - -[[package]] -name = "filelock" -version = "3.12.2" -description = "A platform independent file lock." 
-optional = false -python-versions = ">=3.7" -files = [ +filelock = [ {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, ] - -[package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] - -[[package]] -name = "fsspec" -version = "2023.6.0" -description = "File-system specification" -optional = false -python-versions = ">=3.8" -files = [ +fsspec = [ {file = "fsspec-2023.6.0-py3-none-any.whl", hash = "sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a"}, {file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"}, ] - -[package.extras] -abfs = ["adlfs"] -adl = ["adlfs"] -arrow = ["pyarrow (>=1)"] -dask = ["dask", "distributed"] -devel = ["pytest", "pytest-cov"] -dropbox = ["dropbox", "dropboxdrivefs", "requests"] -full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] -fuse = ["fusepy"] -gcs = ["gcsfs"] -git = ["pygit2"] -github = ["requests"] -gs = ["gcsfs"] -gui = ["panel"] -hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] -libarchive = ["libarchive-c"] -oci = ["ocifs"] -s3 = ["s3fs"] -sftp = ["paramiko"] -smb = ["smbprotocol"] -ssh = ["paramiko"] -tqdm = ["tqdm"] - -[[package]] -name = "googleapis-common-protos" -version = "1.59.1" -description = "Common protobufs used in Google APIs" -optional = false -python-versions = ">=3.7" -files = [ +googleapis-common-protos = [ {file = "googleapis-common-protos-1.59.1.tar.gz", hash = "sha256:b35d530fe825fb4227857bc47ad84c33c809ac96f312e13182bdeaa2abe1178a"}, {file = "googleapis_common_protos-1.59.1-py2.py3-none-any.whl", hash = "sha256:0cbedb6fb68f1c07e18eb4c48256320777707e7d0c55063ae56c15db3224a61e"}, ] - -[package.dependencies] -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" - -[package.extras] -grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] - -[[package]] -name = "grpc-interceptor" -version = "0.15.2" -description = "Simplifies gRPC interceptors" -optional = false -python-versions = ">=3.7,<4.0" -files = [ +grpc-interceptor = [ {file = "grpc-interceptor-0.15.2.tar.gz", hash = "sha256:5c984110af4fb77d03472ec0468f9c77ddaf798e190410fb7b7f1e76c60c96a4"}, {file = "grpc_interceptor-0.15.2-py3-none-any.whl", hash = "sha256:596dac3cb709ffb6178a4873f5148e254c871c9069f0b11040189b257969490a"}, ] - -[package.dependencies] -grpcio = ">=1.49.1,<2.0.0" - -[package.extras] -testing = ["protobuf (>=4.21.9)"] - -[[package]] -name = "grpcio" -version = "1.54.2" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.7" -files = [ - {file = "grpcio-1.54.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c"}, - {file = "grpcio-1.54.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea"}, - {file = 
"grpcio-1.54.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c"}, - {file = "grpcio-1.54.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148"}, - {file = "grpcio-1.54.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8"}, - {file = "grpcio-1.54.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08"}, - {file = "grpcio-1.54.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee"}, - {file = "grpcio-1.54.2-cp310-cp310-win32.whl", hash = "sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8"}, - {file = "grpcio-1.54.2-cp310-cp310-win_amd64.whl", hash = "sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3"}, - {file = "grpcio-1.54.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821"}, - {file = "grpcio-1.54.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7"}, - {file = "grpcio-1.54.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e"}, - {file = "grpcio-1.54.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5"}, - {file = "grpcio-1.54.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a"}, - {file = "grpcio-1.54.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796"}, - {file = "grpcio-1.54.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610"}, - {file = "grpcio-1.54.2-cp311-cp311-win32.whl", hash = "sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e"}, - {file = "grpcio-1.54.2-cp311-cp311-win_amd64.whl", hash = "sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b"}, - {file = "grpcio-1.54.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c"}, - {file = "grpcio-1.54.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94"}, - {file = "grpcio-1.54.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2"}, - {file = "grpcio-1.54.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b"}, - {file = "grpcio-1.54.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe"}, - {file = "grpcio-1.54.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb"}, - {file = "grpcio-1.54.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb"}, - {file = "grpcio-1.54.2-cp37-cp37m-win_amd64.whl", hash = 
"sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474"}, - {file = "grpcio-1.54.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a"}, - {file = "grpcio-1.54.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf"}, - {file = "grpcio-1.54.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3"}, - {file = "grpcio-1.54.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05"}, - {file = "grpcio-1.54.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598"}, - {file = "grpcio-1.54.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb"}, - {file = "grpcio-1.54.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4"}, - {file = "grpcio-1.54.2-cp38-cp38-win32.whl", hash = "sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771"}, - {file = "grpcio-1.54.2-cp38-cp38-win_amd64.whl", hash = "sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12"}, - {file = "grpcio-1.54.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98"}, - {file = "grpcio-1.54.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1"}, - {file = "grpcio-1.54.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f"}, - {file = "grpcio-1.54.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f"}, - {file = "grpcio-1.54.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c"}, - {file = "grpcio-1.54.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a"}, - {file = "grpcio-1.54.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c"}, - {file = "grpcio-1.54.2-cp39-cp39-win32.whl", hash = "sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa"}, - {file = "grpcio-1.54.2-cp39-cp39-win_amd64.whl", hash = "sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b"}, - {file = "grpcio-1.54.2.tar.gz", hash = "sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019"}, +grpcio = [ + {file = "grpcio-1.56.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:fb34ace11419f1ae321c36ccaa18d81cd3f20728cd191250be42949d6845bb2d"}, + {file = "grpcio-1.56.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:008767c0aed4899e657b50f2e0beacbabccab51359eba547f860e7c55f2be6ba"}, + {file = "grpcio-1.56.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:17f47aeb9be0da5337f9ff33ebb8795899021e6c0741ee68bd69774a7804ca86"}, + {file = "grpcio-1.56.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43c50d810cc26349b093bf2cfe86756ab3e9aba3e7e681d360930c1268e1399a"}, + {file = "grpcio-1.56.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:187b8f71bad7d41eea15e0c9812aaa2b87adfb343895fffb704fb040ca731863"}, + {file = "grpcio-1.56.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:881575f240eb5db72ddca4dc5602898c29bc082e0d94599bf20588fb7d1ee6a0"}, + {file = "grpcio-1.56.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c243b158dd7585021d16c50498c4b2ec0a64a6119967440c5ff2d8c89e72330e"}, + {file = "grpcio-1.56.0-cp310-cp310-win32.whl", hash = "sha256:8b3b2c7b5feef90bc9a5fa1c7f97637e55ec3e76460c6d16c3013952ee479cd9"}, + {file = "grpcio-1.56.0-cp310-cp310-win_amd64.whl", hash = "sha256:03a80451530fd3b8b155e0c4480434f6be669daf7ecba56f73ef98f94222ee01"}, + {file = "grpcio-1.56.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:64bd3abcf9fb4a9fa4ede8d0d34686314a7075f62a1502217b227991d9ca4245"}, + {file = "grpcio-1.56.0-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:fdc3a895791af4addbb826808d4c9c35917c59bb5c430d729f44224e51c92d61"}, + {file = "grpcio-1.56.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:4f84a6fd4482e5fe73b297d4874b62a535bc75dc6aec8e9fe0dc88106cd40397"}, + {file = "grpcio-1.56.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14e70b4dda3183abea94c72d41d5930c333b21f8561c1904a372d80370592ef3"}, + {file = "grpcio-1.56.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b5ce42a5ebe3e04796246ba50357f1813c44a6efe17a37f8dc7a5c470377312"}, + {file = "grpcio-1.56.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8219f17baf069fe8e42bd8ca0b312b875595e43a70cabf397be4fda488e2f27d"}, + {file = "grpcio-1.56.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:defdd14b518e6e468466f799aaa69db0355bca8d3a5ea75fb912d28ba6f8af31"}, + {file = "grpcio-1.56.0-cp311-cp311-win32.whl", hash = "sha256:50f4daa698835accbbcc60e61e0bc29636c0156ddcafb3891c987e533a0031ba"}, + {file = "grpcio-1.56.0-cp311-cp311-win_amd64.whl", hash = "sha256:59c4e606993a47146fbeaf304b9e78c447f5b9ee5641cae013028c4cca784617"}, + {file = "grpcio-1.56.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:b1f4b6f25a87d80b28dd6d02e87d63fe1577fe6d04a60a17454e3f8077a38279"}, + {file = "grpcio-1.56.0-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:c2148170e01d464d41011a878088444c13413264418b557f0bdcd1bf1b674a0e"}, + {file = "grpcio-1.56.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:0409de787ebbf08c9d2bca2bcc7762c1efe72eada164af78b50567a8dfc7253c"}, + {file = "grpcio-1.56.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66f0369d27f4c105cd21059d635860bb2ea81bd593061c45fb64875103f40e4a"}, + {file = "grpcio-1.56.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38fdf5bd0a1c754ce6bf9311a3c2c7ebe56e88b8763593316b69e0e9a56af1de"}, + {file = "grpcio-1.56.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:79d4c5911d12a7aa671e5eb40cbb50a830396525014d2d6f254ea2ba180ce637"}, + {file = "grpcio-1.56.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:5d2fc471668a7222e213f86ef76933b18cdda6a51ea1322034478df8c6519959"}, + {file = "grpcio-1.56.0-cp37-cp37m-win_amd64.whl", hash = "sha256:991224fd485e088d3cb5e34366053691a4848a6b7112b8f5625a411305c26691"}, + {file = "grpcio-1.56.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c6f36621aabecbaff3e70c4d1d924c76c8e6a7ffec60c331893640a4af0a8037"}, + {file = "grpcio-1.56.0-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:1eadd6de258901929223f422ffed7f8b310c0323324caf59227f9899ea1b1674"}, + {file = "grpcio-1.56.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = 
"sha256:72836b5a1d4f508ffbcfe35033d027859cc737972f9dddbe33fb75d687421e2e"}, + {file = "grpcio-1.56.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f92a99ab0c7772fb6859bf2e4f44ad30088d18f7c67b83205297bfb229e0d2cf"}, + {file = "grpcio-1.56.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa08affbf672d051cd3da62303901aeb7042a2c188c03b2c2a2d346fc5e81c14"}, + {file = "grpcio-1.56.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2db108b4c8e29c145e95b0226973a66d73ae3e3e7fae00329294af4e27f1c42"}, + {file = "grpcio-1.56.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8674fdbd28266d8efbcddacf4ec3643f76fe6376f73283fd63a8374c14b0ef7c"}, + {file = "grpcio-1.56.0-cp38-cp38-win32.whl", hash = "sha256:bd55f743e654fb050c665968d7ec2c33f03578a4bbb163cfce38024775ff54cc"}, + {file = "grpcio-1.56.0-cp38-cp38-win_amd64.whl", hash = "sha256:c63bc5ac6c7e646c296fed9139097ae0f0e63f36f0864d7ce431cce61fe0118a"}, + {file = "grpcio-1.56.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:c0bc9dda550785d23f4f025be614b7faa8d0293e10811f0f8536cf50435b7a30"}, + {file = "grpcio-1.56.0-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:d596408bab632ec7b947761e83ce6b3e7632e26b76d64c239ba66b554b7ee286"}, + {file = "grpcio-1.56.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:76b6e6e1ee9bda32e6e933efd61c512e9a9f377d7c580977f090d1a9c78cca44"}, + {file = "grpcio-1.56.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7beb84ebd0a3f732625124b73969d12b7350c5d9d64ddf81ae739bbc63d5b1ed"}, + {file = "grpcio-1.56.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83ec714bbbe9b9502177c842417fde39f7a267031e01fa3cd83f1ca49688f537"}, + {file = "grpcio-1.56.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4feee75565d1b5ab09cb3a5da672b84ca7f6dd80ee07a50f5537207a9af543a4"}, + {file = "grpcio-1.56.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b4638a796778329cc8e142e4f57c705adb286b3ba64e00b0fa91eeb919611be8"}, + {file = "grpcio-1.56.0-cp39-cp39-win32.whl", hash = "sha256:437af5a7673bca89c4bc0a993382200592d104dd7bf55eddcd141cef91f40bab"}, + {file = "grpcio-1.56.0-cp39-cp39-win_amd64.whl", hash = "sha256:4241a1c2c76e748023c834995cd916570e7180ee478969c2d79a60ce007bc837"}, + {file = "grpcio-1.56.0.tar.gz", hash = "sha256:4c08ee21b3d10315b8dc26f6c13917b20ed574cdbed2d2d80c53d5508fdcc0f2"}, ] - -[package.extras] -protobuf = ["grpcio-tools (>=1.54.2)"] - -[[package]] -name = "grpcio-reflection" -version = "1.54.2" -description = "Standard Protobuf Reflection Service for gRPC" -optional = false -python-versions = ">=3.6" -files = [ - {file = "grpcio-reflection-1.54.2.tar.gz", hash = "sha256:b2e021e1ce4f075615411edfbbd6fdcc485ba474dd6e5a3f559690582959a673"}, - {file = "grpcio_reflection-1.54.2-py3-none-any.whl", hash = "sha256:e7759addebbd90768f3a0278320278145758c4687d9e2cd7d76e7cbd0e329274"}, +grpcio-reflection = [ + {file = "grpcio-reflection-1.56.0.tar.gz", hash = "sha256:d6bad11af658c78170b50825cf3c841f69662f946f8e921c9fa1ae2dc48f56c1"}, + {file = "grpcio_reflection-1.56.0-py3-none-any.whl", hash = "sha256:23572e1cc5674465e54fb5898d7358288e6f6970e2aa21c9eee7ccd674b7388e"}, ] - -[package.dependencies] -grpcio = ">=1.54.2" -protobuf = ">=4.21.6" - -[[package]] -name = "grpcio-status" -version = "1.54.2" -description = "Status proto mapping for gRPC" -optional = false -python-versions = ">=3.6" -files = [ - {file = "grpcio-status-1.54.2.tar.gz", hash = "sha256:3255cbec5b7c706caa3d4dd584606c080e6415e15631bb2f6215e2b70055836d"}, - {file 
= "grpcio_status-1.54.2-py3-none-any.whl", hash = "sha256:2a7cb4838225f1b53bd0448a3008c5b5837941e1f3a0b13fa38768f08a7b68c2"}, +grpcio-status = [ + {file = "grpcio-status-1.56.0.tar.gz", hash = "sha256:9eca0b2dcda0782d3702df225918efd6d820f75f93cd5c51c7fb6a4ffbfea12c"}, + {file = "grpcio_status-1.56.0-py3-none-any.whl", hash = "sha256:e5f101c96686e9d4e94a114567960fdb00052aa3c818b029745e3db37dc9c613"}, ] - -[package.dependencies] -googleapis-common-protos = ">=1.5.5" -grpcio = ">=1.54.2" -protobuf = ">=4.21.6" - -[[package]] -name = "grpcio-tools" -version = "1.54.2" -description = "Protobuf code generator for gRPC" -optional = false -python-versions = ">=3.7" -files = [ - {file = "grpcio-tools-1.54.2.tar.gz", hash = "sha256:e11c2c2aee53f340992e8e4d6a59172cbbbd0193f1351de98c4f810a5041d5ca"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:2b96f5f17d3156058be247fd25b062b4768138665694c00b056659618b8fb418"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:11939c9a8a39bd4815c7e88cb2fee48e1948775b59dbb06de8fcae5991e84f9e"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:129de5579f95d6a55dde185f188b4cbe19d1e2f1471425431d9930c31d300d70"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4128c01cd6f5ea8f7c2db405dbfd8582cd967d36e6fa0952565436633b0e591"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5c7292dd899ad8fa09a2be96719648cee37b17909fe8c12007e3bff58ebee61"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5ef30c2dbc63c1e0a462423ca4f95001814d26ef4fe66208e53fcf220ea3b717"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4abfc1892380abe6cef381eab86f9350cbd703bfe5d834095aa66fd91c886b6d"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-win32.whl", hash = "sha256:9acf443dcf6f68fbea3b7fb519e1716e014db1a561939f5aecc4abda74e4015d"}, - {file = "grpcio_tools-1.54.2-cp310-cp310-win_amd64.whl", hash = "sha256:21b9d2dee80f3f77e4097252e7f0db89772335a7300b72ab3d2e5c280872b1db"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:7b24fbab9e7598518ce4549e066df00aab79c2bf9bedcdde23fb5ef6a3cf532f"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:7baa210c20f71a242d9ae0e02734628f6948e8bee3bf538647894af427d28800"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:e3d0e5188ff8dbaddac2ee44731d36f09c4eccd3eac7328e547862c44f75cacd"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27671c68c7e0e3c5ff9967f5500799f65a04e7b153b8ce10243c87c43199039d"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f39d8e8806b8857fb473ca6a9c7bd800b0673dfdb7283ff569af0345a222f32c"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8e4c5a48f7b2e8798ce381498ee7b9a83c65b87ae66ee5022387394e5eb51771"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f285f8ef3de422717a36bd372239ae778b8cc112ce780ca3c7fe266dadc49fb"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-win32.whl", hash = "sha256:0f952c8a5c47e9204fe8959f7e9add149e660f6579d67cf65024c32736d34caf"}, - {file = "grpcio_tools-1.54.2-cp311-cp311-win_amd64.whl", hash = 
"sha256:3237149beec39e897fd62cef4aa1e1cd9422d7a95661d24bd0a79200b167e730"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:0ab1b323905d449298523db5d34fa5bf5fffd645bd872b25598e2f8a01f0ea39"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:7d7e6e8d62967b3f037f952620cb7381cc39a4bd31790c75fcfba56cc975d70b"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7f4624ef2e76a3a5313c4e61a81be38bcc16b59a68a85d30758b84cd2102b161"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e543f457935ba7b763b121f1bf893974393b4d30065042f947f85a8d81081b80"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0239b929eb8b3b30b2397eef3b9abb245087754d77c3721e3be43c44796de87d"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:0de05c7698c655e9a240dc34ae91d6017b93143ac89e5b20046d7ca3bd09c27c"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a3ce0b98fb581c471424d2cda45120f57658ed97677c6fec4d6decf5d7c1b976"}, - {file = "grpcio_tools-1.54.2-cp37-cp37m-win_amd64.whl", hash = "sha256:37393ef90674964175923afe3859fc5a208e1ece565f642b4f76a8c0224a0993"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:8e4531267736d88fde1022b36dd42ed8163e3575bcbd12bfed96662872aa93fe"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:a0b7049814442f918b522d66b1d015286afbeb9e6d141af54bbfafe31710a3c8"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:b80585e06c4f0082327eb5c9ad96fbdb2b0e7c14971ea5099fe78c22f4608451"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39fd530cfdf58dc05125775cc233b05554d553d27478f14ae5fd8a6306f0cb28"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb9ec4aea0f2b3006fb002fa59e5c10f92b48fc374619fbffd14d2b0e388c3e"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d512de051342a576bb89777476d13c5266d9334cf4badb6468aed9dc8f5bdec1"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1b8ee3099c51ce987fa8a08e6b93fc342b10228415dd96b5c0caa0387f636a6f"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-win32.whl", hash = "sha256:6037f123905dc0141f7c8383ca616ef0195e79cd3b4d82faaee789d4045e891b"}, - {file = "grpcio_tools-1.54.2-cp38-cp38-win_amd64.whl", hash = "sha256:10dd41862f579d185c60f629b5ee89103e216f63b576079d258d974d980bad87"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:f6787d07fdab31a32c433c1ba34883dea6559d8a3fbe08fb93d834ca34136b71"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:21b1467e31e44429d2a78b50135c9cdbd4b8f6d3b5cd548bc98985d3bdc352d0"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:30a49b8b168aced2a4ff40959e6c4383ad6cfd7a20839a47a215e9837eb722dc"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8742122782953d2fd038f0a199f047a24e941cc9718b1aac90876dbdb7167739"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503ef1351c62fb1d6747eaf74932b609d8fdd4345b3591ef910adef8fa9969d0"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:72d15de4c4b6a764a76c4ae69d99c35f7a0751223688c3f7e62dfa95eb4f61be"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:df079479fb1b9e488334312e35ebbf30cbf5ecad6c56599f1a961800b33ab7c1"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-win32.whl", hash = "sha256:49c2846dcc4803476e839d8bd4db8845e928f19130e0ea86121f2d1f43d2b452"}, - {file = "grpcio_tools-1.54.2-cp39-cp39-win_amd64.whl", hash = "sha256:b82ca472db9c914c44e39a41e9e8bd3ed724523dd7aff5ce37592b8d16920ed9"}, +grpcio-tools = [ + {file = "grpcio-tools-1.56.0.tar.gz", hash = "sha256:39f5877cea514b3da9f2683dfb3ffb45ef47b05f4ff39c287d7d61c5057f48b8"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:cdbae7312e6d132d38ec2c1611b8cafb783e0416cc5c6deae04efde5f16fb190"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5f5c416b88d76fbdb548cfee0486928748816b700ece6e591006e5b1dc67598f"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:23e2ef1dc6a9bf766f091e2c52a68e54d0aff3548f94562e61fb0ac3874d514a"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8870ab60f8a76b4a7e43184ee03d28112b976d83c43d41cec821f47b3a297da2"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e59ab6c0bf4a8bb975553ad578d4425bd192775ae384f9406d77d31ad00f6efe"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b309659534b5d930f9ab6d521670c2dd86cb6ef7f47f37f73f96557e2ec13a49"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8115b416ea2cad8a87dc3aadfaf26da684e003c3770b12e7219b462505bb5b85"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-win32.whl", hash = "sha256:e4cb62a521efbca4cb1ad50233aa400574b3daaf6eb26707d661a0afe8191d92"}, + {file = "grpcio_tools-1.56.0-cp310-cp310-win_amd64.whl", hash = "sha256:4d59009ed52220eb2d62f5cefa4e58dec930fb92fab27bb390c4cf1d360ac7e1"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:cd69107705794e815a8b262722c6fea995911cb1dfc1310abf63b476165335d6"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:2d1ee9e13ce135a6ed451b428ef14af131dc7df2551a5344ff4f8aee2d9fab99"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:142530b9fdfabe04f0c7e5dacd45b6c419d39704fa439cc0aabf73ea0d8f916d"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b7a4eb5003a29eecd71707589f93ae7e8fa2e681366a811b3f86695055d8666"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa6d9bdd75d3625dae38372b43696e159c10aa98719b4302b1e94f1ff7878d47"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c43b4fe8c8df4c52d3106bba2cf427f0e46bbebb80e127fbbc3134db0fead7be"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:168940a4a955b6c65da978dbf62e1c36e3a311bb27f649fd201a228e2583a6d4"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-win32.whl", hash = "sha256:3a4b06169493f9454a7f2516c5d41b566d9734e553bbc505f2a7837f7f4a2df1"}, + {file = "grpcio_tools-1.56.0-cp311-cp311-win_amd64.whl", hash = "sha256:1bd361fcc967c21672ba855fc77ea0e7afa51664033a746df96545f84edc4670"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-linux_armv7l.whl", hash = 
"sha256:7e6bcb194b81e372411494d8ed69fab89aa3452b7275fce4f7917fbe7b04fb72"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:02b23a12b91287ebea14b3685735d1d675e77c3cd365ec1771c3e9afbeba1ec6"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:80d75856f8ec949847386ad2f56a460f21c63bf82ce99ca5b6aa512c0b875fb1"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9cffff0b4af80285fa49637d69b69d640eb775dc74b23635e4de5faad9e7e744"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3de6c08b545920a39b31ed13305f946c00b19ac1b13d26119f111b6360f22ccf"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:128bb13fe9a2681eeb08175f5fbc8e2d8953d7d0dd240e96f9244b9d2547a1aa"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b57f7f01eafbfe3a293f2efffb675774dbe4074c4627975ec4dc4aa5766801fb"}, + {file = "grpcio_tools-1.56.0-cp37-cp37m-win_amd64.whl", hash = "sha256:282176066fb082ad21c403b84f9d6b440a20482e6f52b83bb2adf54d6fdcae9f"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:d9b8d1c42854d3433c058795f52b1418b53dd8c1e9811fecb1312202e803a2c5"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:accf713f51da74b1a18aa4b31df0ab135510704661f735a938081777b79a4c25"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:ac33fd2d02d24101ea389be8e05b928acb58be56403d4ebc3aecfab473fa4a25"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4acdc7b957abfd76581717f0ac8e4408e0a85b7d0ac8d2cdf4d964f16926b897"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79291bfb1fe5f21d99f4839f43d3c5d44c5402c830a24dbb2811d785dd21264b"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0a8767e4de0f573c678313c5de075ac0e163a192bb135018e45015a22f234387"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:96fe2f7f5805d88cb7f2e3e3502550b2883dfab0f9efcf3cbd444942cf2ee1da"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-win32.whl", hash = "sha256:21cf32ccffd4f1800b0dcdf58aa1fc7f626795c9da784c3d817c944edcf2d3ae"}, + {file = "grpcio_tools-1.56.0-cp38-cp38-win_amd64.whl", hash = "sha256:f3ab1a9fad636302f7307d143f64a9fbd11bc041652bf53bb016006e9a5ca820"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:8989d363ac1996238fee61c8f5663f15a8fc362cb1e758c4a686b76cb457cd70"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:11cdd9cbf0c09c3a761c6f59dfd7128104be7cd393334efe386d4fc3f990ee1a"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:5fd4c005a4afec16578849bc522ddf3298d6d499b3d37bf51314b086c714cdd5"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7302acaa07cf4966c926fcd6a60c8d30a697f730c38168bf83e1519b464115b"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c1c43d185ebf904c3deec23c36ca2ba4e95db999cf00fc8f85eda4551622a26"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b12bb8c1d408ae40e4c806a3a8ebda2d107310e46696e1da13d0dc3f91fbd19d"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash 
= "sha256:781cf09e4d5c9288708f6ec9c3eae64d9d5a0f4c46c7ebe70ebb7ab4f6384789"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-win32.whl", hash = "sha256:c62f07452dee3f1ed23aeaef821797c5e516f79535e97fe6a6b0a0ee8db1cc91"}, + {file = "grpcio_tools-1.56.0-cp39-cp39-win_amd64.whl", hash = "sha256:7f063443870650e55012fdb3a58ff4ce5f4042b81dad6b749333ee8146157511"}, ] - -[package.dependencies] -grpcio = ">=1.54.2" -protobuf = ">=4.21.6,<5.0dev" -setuptools = "*" - -[[package]] -name = "hf-transfer" -version = "0.1.3" -description = "" -optional = false -python-versions = ">=3.7" -files = [ +hf-transfer = [ {file = "hf_transfer-0.1.3-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:862b6ddba8e236bdc73408c20d020cfe5069cac3fd0b6de901c46f031df2b7d9"}, {file = "hf_transfer-0.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:569ef1ec6fec182e706ade4ea0c63f8510fd618ed7ced7c772efaafac7245b07"}, {file = "hf_transfer-0.1.3-cp310-none-win_amd64.whl", hash = "sha256:c9faa88b3491c50d4aa75faf18ae24040cd91aa0565c7f7ba2357dbcbf8372f6"}, @@ -470,102 +1065,27 @@ files = [ {file = "hf_transfer-0.1.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efb8b41360c7e3d7700c147b70688aed0a03e86fbe5bcfdee079b0e634f026f9"}, {file = "hf_transfer-0.1.3.tar.gz", hash = "sha256:7afd7eb03efad7812a48591b639b2e3f3d1f93c1e9060c18cc63ebf08d7e193c"}, ] - -[[package]] -name = "huggingface-hub" -version = "0.14.1" -description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -optional = false -python-versions = ">=3.7.0" -files = [ +huggingface-hub = [ {file = "huggingface_hub-0.14.1-py3-none-any.whl", hash = "sha256:9fc619170d800ff3793ad37c9757c255c8783051e1b5b00501205eb43ccc4f27"}, {file = "huggingface_hub-0.14.1.tar.gz", hash = "sha256:9ab899af8e10922eac65e290d60ab956882ab0bf643e3d990b1394b6b47b7fbc"}, ] - -[package.dependencies] -filelock = "*" -fsspec = "*" -packaging = ">=20.9" -pyyaml = ">=5.1" -requests = "*" -tqdm = ">=4.42.1" -typing-extensions = ">=3.7.4.3" - -[package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] -cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] -fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] -tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "gradio", "jedi", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile"] -torch = ["torch"] -typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] - -[[package]] -name = "idna" -version = "3.4" -description = "Internationalized Domain Names in Applications (IDNA)" -optional = false -python-versions = ">=3.5" -files = [ +idna = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4.tar.gz", hash = 
"sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] - -[[package]] -name = "iniconfig" -version = "2.0.0" -description = "brain-dead simple config-ini parsing" -optional = false -python-versions = ">=3.7" -files = [ +iniconfig = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] - -[[package]] -name = "jinja2" -version = "3.1.2" -description = "A very fast and expressive template engine." -optional = true -python-versions = ">=3.7" -files = [ +Jinja2 = [ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, ] - -[package.dependencies] -MarkupSafe = ">=2.0" - -[package.extras] -i18n = ["Babel (>=2.7)"] - -[[package]] -name = "loguru" -version = "0.6.0" -description = "Python logging made (stupidly) simple" -optional = false -python-versions = ">=3.5" -files = [ +loguru = [ {file = "loguru-0.6.0-py3-none-any.whl", hash = "sha256:4e2414d534a2ab57573365b3e6d0234dfb1d84b68b7f3b948e6fb743860a77c3"}, {file = "loguru-0.6.0.tar.gz", hash = "sha256:066bd06758d0a513e9836fd9c6b5a75bfb3fd36841f4b996bc60b547a309d41c"}, ] - -[package.dependencies] -colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} -win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} - -[package.extras] -dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] - -[[package]] -name = "markupsafe" -version = "2.1.3" -description = "Safely add untrusted strings to HTML/XML markup." 
-optional = true -python-versions = ">=3.7" -files = [ +MarkupSafe = [ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, @@ -617,288 +1137,101 @@ files = [ {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] - -[[package]] -name = "mpmath" -version = "1.3.0" -description = "Python library for arbitrary-precision floating-point arithmetic" -optional = true -python-versions = "*" -files = [ +mpmath = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, ] - -[package.extras] -develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] -docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4)"] -tests = ["pytest (>=4.6)"] - -[[package]] -name = "networkx" -version = "3.1" -description = "Python package for creating and manipulating graphs and networks" -optional = true -python-versions = ">=3.8" -files = [ +networkx = [ {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, ] - -[package.extras] -default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] -developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] -doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] -test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] - -[[package]] -name = "numpy" -version = "1.24.3" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "numpy-1.24.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c1104d3c036fb81ab923f507536daedc718d0ad5a8707c6061cdfd6d184e570"}, - {file = "numpy-1.24.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:202de8f38fc4a45a3eea4b63e2f376e5f2dc64ef0fa692838e31a808520efaf7"}, - {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8535303847b89aa6b0f00aa1dc62867b5a32923e4d1681a35b5eef2d9591a463"}, - {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d926b52ba1367f9acb76b0df6ed21f0b16a1ad87c6720a1121674e5cf63e2b6"}, - {file = "numpy-1.24.3-cp310-cp310-win32.whl", hash = "sha256:f21c442fdd2805e91799fbe044a7b999b8571bb0ab0f7850d0cb9641a687092b"}, - {file = "numpy-1.24.3-cp310-cp310-win_amd64.whl", hash = "sha256:ab5f23af8c16022663a652d3b25dcdc272ac3f83c3af4c02eb8b824e6b3ab9d7"}, - {file = "numpy-1.24.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:9a7721ec204d3a237225db3e194c25268faf92e19338a35f3a224469cb6039a3"}, - {file = "numpy-1.24.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d6cc757de514c00b24ae8cf5c876af2a7c3df189028d68c0cb4eaa9cd5afc2bf"}, - {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e3f4e85fc5d4fd311f6e9b794d0c00e7002ec122be271f2019d63376f1d385"}, - {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1d3c026f57ceaad42f8231305d4653d5f05dc6332a730ae5c0bea3513de0950"}, - {file = "numpy-1.24.3-cp311-cp311-win32.whl", hash = "sha256:c91c4afd8abc3908e00a44b2672718905b8611503f7ff87390cc0ac3423fb096"}, - {file = "numpy-1.24.3-cp311-cp311-win_amd64.whl", hash = "sha256:5342cf6aad47943286afa6f1609cad9b4266a05e7f2ec408e2cf7aea7ff69d80"}, - {file = "numpy-1.24.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7776ea65423ca6a15255ba1872d82d207bd1e09f6d0894ee4a64678dd2204078"}, - {file = "numpy-1.24.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ae8d0be48d1b6ed82588934aaaa179875e7dc4f3d84da18d7eae6eb3f06c242c"}, - {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecde0f8adef7dfdec993fd54b0f78183051b6580f606111a6d789cd14c61ea0c"}, - {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4749e053a29364d3452c034827102ee100986903263e89884922ef01a0a6fd2f"}, - {file = "numpy-1.24.3-cp38-cp38-win32.whl", hash = "sha256:d933fabd8f6a319e8530d0de4fcc2e6a61917e0b0c271fded460032db42a0fe4"}, - {file = "numpy-1.24.3-cp38-cp38-win_amd64.whl", hash = "sha256:56e48aec79ae238f6e4395886b5eaed058abb7231fb3361ddd7bfdf4eed54289"}, - {file = "numpy-1.24.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4719d5aefb5189f50887773699eaf94e7d1e02bf36c1a9d353d9f46703758ca4"}, - {file = "numpy-1.24.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0ec87a7084caa559c36e0a2309e4ecb1baa03b687201d0a847c8b0ed476a7187"}, - {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8282b9bcfe2b5e7d491d0bf7f3e2da29700cec05b49e64d6246923329f2b02"}, - {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210461d87fb02a84ef243cac5e814aad2b7f4be953b32cb53327bb49fd77fbb4"}, - {file = "numpy-1.24.3-cp39-cp39-win32.whl", hash = "sha256:784c6da1a07818491b0ffd63c6bbe5a33deaa0e25a20e1b3ea20cf0e43f8046c"}, - {file = "numpy-1.24.3-cp39-cp39-win_amd64.whl", hash = "sha256:d5036197ecae68d7f491fcdb4df90082b0d4960ca6599ba2659957aafced7c17"}, - {file = "numpy-1.24.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:352ee00c7f8387b44d19f4cada524586f07379c0d49270f87233983bc5087ca0"}, - {file = "numpy-1.24.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7d6acc2e7524c9955e5c903160aa4ea083736fde7e91276b0e5d98e6332812"}, - {file = "numpy-1.24.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:35400e6a8d102fd07c71ed7dcadd9eb62ee9a6e84ec159bd48c28235bbb0f8e4"}, - {file = "numpy-1.24.3.tar.gz", hash = "sha256:ab344f1bf21f140adab8e47fdbc7c35a477dc01408791f8ba00d018dd0bc5155"}, +numpy = [ + {file = "numpy-1.25.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8aa130c3042052d656751df5e81f6d61edff3e289b5994edcf77f54118a8d9f4"}, + {file = "numpy-1.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e3f2b96e3b63c978bc29daaa3700c028fe3f049ea3031b58aa33fe2a5809d24"}, + {file = "numpy-1.25.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:d6b267f349a99d3908b56645eebf340cb58f01bd1e773b4eea1a905b3f0e4208"}, + {file = "numpy-1.25.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4aedd08f15d3045a4e9c648f1e04daca2ab1044256959f1f95aafeeb3d794c16"}, + {file = "numpy-1.25.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6d183b5c58513f74225c376643234c369468e02947b47942eacbb23c1671f25d"}, + {file = "numpy-1.25.0-cp310-cp310-win32.whl", hash = "sha256:d76a84998c51b8b68b40448ddd02bd1081bb33abcdc28beee6cd284fe11036c6"}, + {file = "numpy-1.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:c0dc071017bc00abb7d7201bac06fa80333c6314477b3d10b52b58fa6a6e38f6"}, + {file = "numpy-1.25.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c69fe5f05eea336b7a740e114dec995e2f927003c30702d896892403df6dbf0"}, + {file = "numpy-1.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9c7211d7920b97aeca7b3773a6783492b5b93baba39e7c36054f6e749fc7490c"}, + {file = "numpy-1.25.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc68f11404930e9c7ecfc937aa423e1e50158317bf67ca91736a9864eae0232"}, + {file = "numpy-1.25.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e559c6afbca484072a98a51b6fa466aae785cfe89b69e8b856c3191bc8872a82"}, + {file = "numpy-1.25.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6c284907e37f5e04d2412950960894b143a648dea3f79290757eb878b91acbd1"}, + {file = "numpy-1.25.0-cp311-cp311-win32.whl", hash = "sha256:95367ccd88c07af21b379be1725b5322362bb83679d36691f124a16357390153"}, + {file = "numpy-1.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:b76aa836a952059d70a2788a2d98cb2a533ccd46222558b6970348939e55fc24"}, + {file = "numpy-1.25.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b792164e539d99d93e4e5e09ae10f8cbe5466de7d759fc155e075237e0c274e4"}, + {file = "numpy-1.25.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7cd981ccc0afe49b9883f14761bb57c964df71124dcd155b0cba2b591f0d64b9"}, + {file = "numpy-1.25.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aa48bebfb41f93043a796128854b84407d4df730d3fb6e5dc36402f5cd594c0"}, + {file = "numpy-1.25.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5177310ac2e63d6603f659fadc1e7bab33dd5a8db4e0596df34214eeab0fee3b"}, + {file = "numpy-1.25.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0ac6edfb35d2a99aaf102b509c8e9319c499ebd4978df4971b94419a116d0790"}, + {file = "numpy-1.25.0-cp39-cp39-win32.whl", hash = "sha256:7412125b4f18aeddca2ecd7219ea2d2708f697943e6f624be41aa5f8a9852cc4"}, + {file = "numpy-1.25.0-cp39-cp39-win_amd64.whl", hash = "sha256:26815c6c8498dc49d81faa76d61078c4f9f0859ce7817919021b9eba72b425e3"}, + {file = "numpy-1.25.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5b1b90860bf7d8a8c313b372d4f27343a54f415b20fb69dd601b7efe1029c91e"}, + {file = "numpy-1.25.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85cdae87d8c136fd4da4dad1e48064d700f63e923d5af6c8c782ac0df8044542"}, + {file = "numpy-1.25.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cc3fda2b36482891db1060f00f881c77f9423eead4c3579629940a3e12095fe8"}, + {file = "numpy-1.25.0.tar.gz", hash = "sha256:f1accae9a28dc3cda46a91de86acf69de0d1b5f4edd44a9b0c3ceb8036dfff19"}, ] - -[[package]] -name = "opentelemetry-api" -version = "1.15.0" -description = "OpenTelemetry Python API" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-api = [ {file = "opentelemetry_api-1.15.0-py3-none-any.whl", hash = 
"sha256:e6c2d2e42140fd396e96edf75a7ceb11073f4efb4db87565a431cc9d0f93f2e0"}, {file = "opentelemetry_api-1.15.0.tar.gz", hash = "sha256:79ab791b4aaad27acc3dc3ba01596db5b5aac2ef75c70622c6038051d6c2cded"}, ] - -[package.dependencies] -deprecated = ">=1.2.6" -setuptools = ">=16.0" - -[[package]] -name = "opentelemetry-exporter-otlp" -version = "1.15.0" -description = "OpenTelemetry Collector Exporters" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-exporter-otlp = [ {file = "opentelemetry_exporter_otlp-1.15.0-py3-none-any.whl", hash = "sha256:79f22748b6a54808a0448093dfa189c8490e729f67c134d4c992533d9393b33e"}, {file = "opentelemetry_exporter_otlp-1.15.0.tar.gz", hash = "sha256:4f7c49751d9720e2e726e13b0bb958ccade4e29122c305d92c033da432c8d2c5"}, ] - -[package.dependencies] -opentelemetry-exporter-otlp-proto-grpc = "1.15.0" -opentelemetry-exporter-otlp-proto-http = "1.15.0" - -[[package]] -name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.15.0" -description = "OpenTelemetry Collector Protobuf over gRPC Exporter" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-exporter-otlp-proto-grpc = [ {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0-py3-none-any.whl", hash = "sha256:c2a5492ba7d140109968135d641d06ce3c5bd73c50665f787526065d57d7fd1d"}, {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0.tar.gz", hash = "sha256:844f2a4bb9bcda34e4eb6fe36765e5031aacb36dc60ed88c90fc246942ea26e7"}, ] - -[package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} -googleapis-common-protos = ">=1.52,<2.0" -grpcio = ">=1.0.0,<2.0.0" -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-proto = "1.15.0" -opentelemetry-sdk = ">=1.12,<2.0" - -[package.extras] -test = ["pytest-grpc"] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-http" -version = "1.15.0" -description = "OpenTelemetry Collector Protobuf over HTTP Exporter" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-exporter-otlp-proto-http = [ {file = "opentelemetry_exporter_otlp_proto_http-1.15.0-py3-none-any.whl", hash = "sha256:3ec2a02196c8a54bf5cbf7fe623a5238625638e83b6047a983bdf96e2bbb74c0"}, {file = "opentelemetry_exporter_otlp_proto_http-1.15.0.tar.gz", hash = "sha256:11b2c814249a49b22f6cca7a06b05701f561d577b747f3660dfd67b6eb9daf9c"}, ] - -[package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} -googleapis-common-protos = ">=1.52,<2.0" -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-proto = "1.15.0" -opentelemetry-sdk = ">=1.12,<2.0" -requests = ">=2.7,<3.0" - -[package.extras] -test = ["responses (==0.22.0)"] - -[[package]] -name = "opentelemetry-instrumentation" -version = "0.36b0" -description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-instrumentation = [ {file = "opentelemetry_instrumentation-0.36b0-py3-none-any.whl", hash = "sha256:83ba4ae7d5292b5b33e0f851cc5c76d8f91196b9b3527800fc13855c33383ac2"}, {file = "opentelemetry_instrumentation-0.36b0.tar.gz", hash = "sha256:e3ddac9b3b93408ef26c8ecbf38f717042977e16381bb4cd329a5b4cf16998cf"}, ] - -[package.dependencies] -opentelemetry-api = ">=1.4,<2.0" -setuptools = ">=16.0" -wrapt = ">=1.0.0,<2.0.0" - -[[package]] -name = "opentelemetry-instrumentation-grpc" -version = "0.36b0" -description = "OpenTelemetry gRPC instrumentation" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-instrumentation-grpc = [ {file = 
"opentelemetry_instrumentation_grpc-0.36b0-py3-none-any.whl", hash = "sha256:eaa246ed2083c97b13bab2555cb9d170e8433230a31476c4cab8a17fa03380a4"}, {file = "opentelemetry_instrumentation_grpc-0.36b0.tar.gz", hash = "sha256:dc89447c9eb6ea868970f6c13b4ffdac182cdd5a41dd215a0f5393ca6375be55"}, ] - -[package.dependencies] -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.36b0" -opentelemetry-sdk = ">=1.12,<2.0" -opentelemetry-semantic-conventions = "0.36b0" -wrapt = ">=1.0.0,<2.0.0" - -[package.extras] -instruments = ["grpcio (>=1.27,<2.0)"] -test = ["opentelemetry-instrumentation-grpc[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.36b0)", "protobuf (>=3.13,<4.0)"] - -[[package]] -name = "opentelemetry-proto" -version = "1.15.0" -description = "OpenTelemetry Python Proto" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-proto = [ {file = "opentelemetry_proto-1.15.0-py3-none-any.whl", hash = "sha256:044b6d044b4d10530f250856f933442b8753a17f94ae37c207607f733fb9a844"}, {file = "opentelemetry_proto-1.15.0.tar.gz", hash = "sha256:9c4008e40ac8cab359daac283fbe7002c5c29c77ea2674ad5626a249e64e0101"}, ] - -[package.dependencies] -protobuf = ">=3.19,<5.0" - -[[package]] -name = "opentelemetry-sdk" -version = "1.15.0" -description = "OpenTelemetry Python SDK" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-sdk = [ {file = "opentelemetry_sdk-1.15.0-py3-none-any.whl", hash = "sha256:555c533e9837766119bbccc7a80458c9971d853a6f1da683a2246cd5e53b4645"}, {file = "opentelemetry_sdk-1.15.0.tar.gz", hash = "sha256:98dbffcfeebcbff12c0c974292d6ea603180a145904cf838b1fe4d5c99078425"}, ] - -[package.dependencies] -opentelemetry-api = "1.15.0" -opentelemetry-semantic-conventions = "0.36b0" -setuptools = ">=16.0" -typing-extensions = ">=3.7.4" - -[[package]] -name = "opentelemetry-semantic-conventions" -version = "0.36b0" -description = "OpenTelemetry Semantic Conventions" -optional = false -python-versions = ">=3.7" -files = [ +opentelemetry-semantic-conventions = [ {file = "opentelemetry_semantic_conventions-0.36b0-py3-none-any.whl", hash = "sha256:adc05635e87b9d3e007c9f530eed487fc3ef2177d02f82f674f28ebf9aff8243"}, {file = "opentelemetry_semantic_conventions-0.36b0.tar.gz", hash = "sha256:829dc221795467d98b773c04096e29be038d77526dc8d6ac76f546fb6279bf01"}, ] - -[[package]] -name = "packaging" -version = "23.1" -description = "Core utilities for Python packages" -optional = false -python-versions = ">=3.7" -files = [ +packaging = [ {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] - -[[package]] -name = "pluggy" -version = "1.0.0" -description = "plugin and hook calling mechanisms for python" -optional = false -python-versions = ">=3.6" -files = [ - {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, - {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, +pluggy = [ + {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, + {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, ] - -[package.extras] -dev = ["pre-commit", "tox"] -testing = ["pytest", "pytest-benchmark"] - 
-[[package]] -name = "protobuf" -version = "4.23.2" -description = "" -optional = false -python-versions = ">=3.7" -files = [ - {file = "protobuf-4.23.2-cp310-abi3-win32.whl", hash = "sha256:384dd44cb4c43f2ccddd3645389a23ae61aeb8cfa15ca3a0f60e7c3ea09b28b3"}, - {file = "protobuf-4.23.2-cp310-abi3-win_amd64.whl", hash = "sha256:09310bce43353b46d73ba7e3bca78273b9bc50349509b9698e64d288c6372c2a"}, - {file = "protobuf-4.23.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b2cfab63a230b39ae603834718db74ac11e52bccaaf19bf20f5cce1a84cf76df"}, - {file = "protobuf-4.23.2-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:c52cfcbfba8eb791255edd675c1fe6056f723bf832fa67f0442218f8817c076e"}, - {file = "protobuf-4.23.2-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:86df87016d290143c7ce3be3ad52d055714ebaebb57cc659c387e76cfacd81aa"}, - {file = "protobuf-4.23.2-cp37-cp37m-win32.whl", hash = "sha256:281342ea5eb631c86697e1e048cb7e73b8a4e85f3299a128c116f05f5c668f8f"}, - {file = "protobuf-4.23.2-cp37-cp37m-win_amd64.whl", hash = "sha256:ce744938406de1e64b91410f473736e815f28c3b71201302612a68bf01517fea"}, - {file = "protobuf-4.23.2-cp38-cp38-win32.whl", hash = "sha256:6c081863c379bb1741be8f8193e893511312b1d7329b4a75445d1ea9955be69e"}, - {file = "protobuf-4.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:25e3370eda26469b58b602e29dff069cfaae8eaa0ef4550039cc5ef8dc004511"}, - {file = "protobuf-4.23.2-cp39-cp39-win32.whl", hash = "sha256:efabbbbac1ab519a514579ba9ec52f006c28ae19d97915951f69fa70da2c9e91"}, - {file = "protobuf-4.23.2-cp39-cp39-win_amd64.whl", hash = "sha256:54a533b971288af3b9926e53850c7eb186886c0c84e61daa8444385a4720297f"}, - {file = "protobuf-4.23.2-py3-none-any.whl", hash = "sha256:8da6070310d634c99c0db7df48f10da495cc283fd9e9234877f0cd182d43ab7f"}, - {file = "protobuf-4.23.2.tar.gz", hash = "sha256:20874e7ca4436f683b64ebdbee2129a5a2c301579a67d1a7dda2cdf62fb7f5f7"}, +protobuf = [ + {file = "protobuf-4.23.3-cp310-abi3-win32.whl", hash = "sha256:514b6bbd54a41ca50c86dd5ad6488afe9505901b3557c5e0f7823a0cf67106fb"}, + {file = "protobuf-4.23.3-cp310-abi3-win_amd64.whl", hash = "sha256:cc14358a8742c4e06b1bfe4be1afbdf5c9f6bd094dff3e14edb78a1513893ff5"}, + {file = "protobuf-4.23.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:2991f5e7690dab569f8f81702e6700e7364cc3b5e572725098215d3da5ccc6ac"}, + {file = "protobuf-4.23.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:08fe19d267608d438aa37019236db02b306e33f6b9902c3163838b8e75970223"}, + {file = "protobuf-4.23.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3b01a5274ac920feb75d0b372d901524f7e3ad39c63b1a2d55043f3887afe0c1"}, + {file = "protobuf-4.23.3-cp37-cp37m-win32.whl", hash = "sha256:aca6e86a08c5c5962f55eac9b5bd6fce6ed98645d77e8bfc2b952ecd4a8e4f6a"}, + {file = "protobuf-4.23.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0149053336a466e3e0b040e54d0b615fc71de86da66791c592cc3c8d18150bf8"}, + {file = "protobuf-4.23.3-cp38-cp38-win32.whl", hash = "sha256:84ea0bd90c2fdd70ddd9f3d3fc0197cc24ecec1345856c2b5ba70e4d99815359"}, + {file = "protobuf-4.23.3-cp38-cp38-win_amd64.whl", hash = "sha256:3bcbeb2bf4bb61fe960dd6e005801a23a43578200ea8ceb726d1f6bd0e562ba1"}, + {file = "protobuf-4.23.3-cp39-cp39-win32.whl", hash = "sha256:5cb9e41188737f321f4fce9a4337bf40a5414b8d03227e1d9fbc59bc3a216e35"}, + {file = "protobuf-4.23.3-cp39-cp39-win_amd64.whl", hash = "sha256:29660574cd769f2324a57fb78127cda59327eb6664381ecfe1c69731b83e8288"}, + {file = "protobuf-4.23.3-py3-none-any.whl", hash = 
"sha256:447b9786ac8e50ae72cae7a2eec5c5df6a9dbf9aa6f908f1b8bda6032644ea62"}, + {file = "protobuf-4.23.3.tar.gz", hash = "sha256:7a92beb30600332a52cdadbedb40d33fd7c8a0d7f549c440347bc606fb3fe34b"}, ] - -[[package]] -name = "psutil" -version = "5.9.5" -description = "Cross-platform lib for process and system monitoring in Python." -optional = true -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ +psutil = [ {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, @@ -914,39 +1247,11 @@ files = [ {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, ] - -[package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] - -[[package]] -name = "pytest" -version = "7.3.2" -description = "pytest: simple powerful testing with Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "pytest-7.3.2-py3-none-any.whl", hash = "sha256:cdcbd012c9312258922f8cd3f1b62a6580fdced17db6014896053d47cddf9295"}, - {file = "pytest-7.3.2.tar.gz", hash = "sha256:ee990a3cc55ba808b80795a79944756f315c67c12b56abd3ac993a7b8c17030b"}, +pytest = [ + {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, + {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, ] - -[package.dependencies] -colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} -iniconfig = "*" -packaging = "*" -pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} - -[package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] - -[[package]] -name = "pyyaml" -version = "6.0" -description = "YAML parser and emitter for Python" -optional = false -python-versions = ">=3.6" -files = [ +PyYAML = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, @@ -988,14 +1293,7 @@ files = [ {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] - -[[package]] -name = "regex" -version = "2023.6.3" -description = "Alternative regular expression module, to replace re." 
-optional = false -python-versions = ">=3.6" -files = [ +regex = [ {file = "regex-2023.6.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:824bf3ac11001849aec3fa1d69abcb67aac3e150a933963fb12bda5151fe1bfd"}, {file = "regex-2023.6.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:05ed27acdf4465c95826962528f9e8d41dbf9b1aa8531a387dee6ed215a3e9ef"}, {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b49c764f88a79160fa64f9a7b425620e87c9f46095ef9c9920542ab2495c8bc"}, @@ -1085,35 +1383,11 @@ files = [ {file = "regex-2023.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:bdff5eab10e59cf26bc479f565e25ed71a7d041d1ded04ccf9aee1d9f208487a"}, {file = "regex-2023.6.3.tar.gz", hash = "sha256:72d1a25bf36d2050ceb35b517afe13864865268dfb45910e2e17a84be6cbfeb0"}, ] - -[[package]] -name = "requests" -version = "2.31.0" -description = "Python HTTP for Humans." -optional = false -python-versions = ">=3.7" -files = [ +requests = [ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, ] - -[package.dependencies] -certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" -idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<3" - -[package.extras] -socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] - -[[package]] -name = "safetensors" -version = "0.3.1" -description = "Fast and Safe Tensor serialization" -optional = false -python-versions = "*" -files = [ +safetensors = [ {file = "safetensors-0.3.1-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:2ae9b7dd268b4bae6624729dac86deb82104820e9786429b0583e5168db2f770"}, {file = "safetensors-0.3.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:08c85c1934682f1e2cd904d38433b53cd2a98245a7cc31f5689f9322a2320bbf"}, {file = "safetensors-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba625c7af9e1c5d0d91cb83d2fba97d29ea69d4db2015d9714d24c7f6d488e15"}, @@ -1155,25 +1429,7 @@ files = [ {file = "safetensors-0.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f4f614b8e8161cd8a9ca19c765d176a82b122fa3d3387b77862145bfe9b4e93"}, {file = "safetensors-0.3.1.tar.gz", hash = "sha256:571da56ff8d0bec8ae54923b621cda98d36dcef10feb36fd492c4d0c2cd0e869"}, ] - -[package.extras] -all = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] -dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] -jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)"] -numpy = ["numpy (>=1.21.6)"] -paddlepaddle = ["paddlepaddle (>=2.4.1)"] -quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] -tensorflow = ["tensorflow (>=2.11.0)"] -testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] -torch = 
["torch (>=1.10)"] - -[[package]] -name = "sentencepiece" -version = "0.1.99" -description = "SentencePiece python wrapper" -optional = false -python-versions = "*" -files = [ +sentencepiece = [ {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73"}, {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7"}, {file = "sentencepiece-0.1.99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0"}, @@ -1220,44 +1476,15 @@ files = [ {file = "sentencepiece-0.1.99-cp39-cp39-win_amd64.whl", hash = "sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd"}, {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, ] - -[[package]] -name = "setuptools" -version = "67.8.0" -description = "Easily download, build, install, upgrade, and uninstall Python packages" -optional = false -python-versions = ">=3.7" -files = [ - {file = "setuptools-67.8.0-py3-none-any.whl", hash = "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f"}, - {file = "setuptools-67.8.0.tar.gz", hash = "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102"}, +setuptools = [ + {file = "setuptools-68.0.0-py3-none-any.whl", hash = "sha256:11e52c67415a381d10d6b462ced9cfb97066179f0e871399e006c4ab101fc85f"}, + {file = "setuptools-68.0.0.tar.gz", hash = "sha256:baf1fdb41c6da4cd2eae722e135500da913332ab3f2f5c7d33af9b492acb5235"}, ] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] - -[[package]] -name = "sympy" -version = "1.12" -description = "Computer algebra system (CAS) in Python" -optional = true -python-versions = ">=3.8" -files = [ +sympy = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, ] - -[package.dependencies] -mpmath = ">=0.19" - -[[package]] -name = "tokenizers" -version = "0.13.3" -description = "Fast and Customizable Tokenizers" -optional = false -python-versions = "*" -files = [ +tokenizers = [ {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, 
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, @@ -1299,30 +1526,11 @@ files = [ {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, ] - -[package.extras] -dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] -testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] - -[[package]] -name = "tomli" -version = "2.0.1" -description = "A lil' TOML parser" -optional = false -python-versions = ">=3.7" -files = [ +tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] - -[[package]] -name = "torch" -version = "2.0.1" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -optional = true -python-versions = ">=3.8.0" -files = [ +torch = [ {file = "torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ced00b3ba471856b993822508f77c98f48a458623596a4c43136158781e306a"}, {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"}, {file = "torch-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7c84e44d9002182edd859f3400deaa7410f5ec948a519cc7ef512c2f9b34d2c4"}, @@ -1344,175 +1552,31 @@ files = [ {file = "torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c62df99352bd6ee5a5a8d1832452110435d178b5164de450831a3a8cc14dc680"}, {file = "torch-2.0.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:671a2565e3f63b8fe8e42ae3e36ad249fe5e567435ea27b94edaa672a7d0c416"}, ] - -[package.dependencies] -filelock = "*" -jinja2 = "*" -networkx = "*" -sympy = "*" -typing-extensions = "*" - -[package.extras] -opt-einsum = ["opt-einsum (>=3.3)"] - -[[package]] -name = "tqdm" -version = "4.65.0" -description = "Fast, Extensible Progress Meter" -optional = false -python-versions = ">=3.7" -files = [ +tqdm = [ {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, ] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -dev = ["py-make (>=0.1.0)", "twine", "wheel"] -notebook = ["ipywidgets (>=6)"] -slack = ["slack-sdk"] -telegram = ["requests"] - -[[package]] -name = "transformers" -version = "4.30.2" -description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "transformers-4.30.2-py3-none-any.whl", hash = "sha256:c332e3a3097f9ed89ce556b403251235931c00237b8bc2d7adaa19d226c13f1d"}, - {file = "transformers-4.30.2.tar.gz", hash = "sha256:f4a8aac4e1baffab4033f4a345b0d7dc7957d12a4f1ba969afea08205a513045"}, +transformers = [ + {file = "transformers-4.29.2-py3-none-any.whl", hash = "sha256:0ef158b99bad6f4e6652a0d8655fbbe58b4cb788ce7040f320b5d29c7c810a75"}, + {file = "transformers-4.29.2.tar.gz", hash = 
"sha256:ed9467661f459f1ce49461d83f18f3b36b6a37f306182dc2ba272935f3b93ebb"}, ] - -[package.dependencies] -filelock = "*" -huggingface-hub = ">=0.14.1,<1.0" -numpy = ">=1.17" -packaging = ">=20.0" -pyyaml = ">=5.1" -regex = "!=2019.12.17" -requests = "*" -safetensors = ">=0.3.1" -tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" -tqdm = ">=4.27" - -[package.extras] -accelerate = ["accelerate (>=0.20.2)"] -agents = ["Pillow", "accelerate (>=0.20.2)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] -all = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] -audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.20.2)", "deepspeed (>=0.8.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.8.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.6.9)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", 
"protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] -docs-specific = ["hf-doc-builder"] -fairscale = ["fairscale (>0.3)"] -flax = ["flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8,<=0.1.4)"] -flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune]", "sigopt"] -ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] -modelcreation = ["cookiecutter (==1.7.3)"] -natten = ["natten (>=0.14.6)"] -onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] -onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] -optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] -ray = ["ray[tune]"] -retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] -sagemaker = ["sagemaker (>=2.31.0)"] -sentencepiece = ["protobuf (<=3.20.3)", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic", "starlette", "uvicorn"] -sigopt = ["sigopt"] -sklearn = ["scikit-learn"] -speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter 
(==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] -tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -timm = ["timm"] -tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] -torch = ["accelerate (>=0.20.2)", "torch (>=1.9,!=1.12.0)"] -torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.14.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.3)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] -video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow"] - -[[package]] -name = "typer" -version = "0.6.1" -description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -optional = false -python-versions = ">=3.6" -files = [ +typer = [ {file = "typer-0.6.1-py3-none-any.whl", hash = "sha256:54b19e5df18654070a82f8c2aa1da456a4ac16a2a83e6dcd9f170e291c56338e"}, {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, ] - -[package.dependencies] -click = ">=7.1.1,<9.0.0" - -[package.extras] -all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] -dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] -doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] -test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] - -[[package]] -name = "typing-extensions" -version = "4.6.3" -description = "Backported and Experimental Type Hints for Python 3.7+" -optional = false -python-versions = ">=3.7" -files = [ - {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"}, - {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, +typing-extensions = [ + {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, + {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, ] - -[[package]] -name = "urllib3" -version = "2.0.3" -description = "HTTP library with thread-safe connection pooling, file post, and more." 
-optional = false -python-versions = ">=3.7" -files = [ +urllib3 = [ {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"}, {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"}, ] - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] -socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["zstandard (>=0.18.0)"] - -[[package]] -name = "win32-setctime" -version = "1.1.0" -description = "A small Python utility to set file creation time on Windows" -optional = false -python-versions = ">=3.5" -files = [ +win32-setctime = [ {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, ] - -[package.extras] -dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] - -[[package]] -name = "wrapt" -version = "1.15.0" -description = "Module for decorators, wrappers and monkey patching." -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" -files = [ +wrapt = [ {file = "wrapt-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ca1cccf838cd28d5a0883b342474c630ac48cac5df0ee6eacc9c7290f76b11c1"}, {file = "wrapt-1.15.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:e826aadda3cae59295b95343db8f3d965fb31059da7de01ee8d1c40a60398b29"}, {file = "wrapt-1.15.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5fc8e02f5984a55d2c653f5fea93531e9836abbd84342c1d1e17abc4a15084c2"}, @@ -1589,12 +1653,3 @@ files = [ {file = "wrapt-1.15.0-py3-none-any.whl", hash = "sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640"}, {file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"}, ] - -[extras] -accelerate = ["accelerate"] -bnb = ["bitsandbytes"] - -[metadata] -lock-version = "2.0" -python-versions = "^3.9" -content-hash = "3174a211d30bed5990ed5f8418416c951bb6c585153fb51b62809baa89ef07d0" diff --git a/server/pyproject.toml b/server/pyproject.toml index bbf5836d..fd640e7f 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,7 +26,7 @@ hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "0.13.3" huggingface-hub = "^0.14.1" -transformers = "^4.29.2" +transformers = "4.29.2" einops = "^0.6.1" [tool.poetry.extras] diff --git a/server/requirements.txt b/server/requirements.txt index 92693bbd..9b8b2164 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,22 +1,23 @@ backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0" +bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0" certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0" charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0" click==8.1.3 ; python_version >= "3.9" and python_version < "4.0" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") +colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" einops==0.6.1 ; python_version >= 
"3.9" and python_version < "4.0" filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0" googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0" grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0" +grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0" +grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0" +grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0" hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0" huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0" idna==3.4 ; python_version >= "3.9" and python_version < "4.0" loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0" -numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0" +numpy==1.25.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0" @@ -27,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0" packaging==23.1 ; python_version >= "3.9" and python_version < "4.0" -protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0" +protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0" pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0" regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0" requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0" safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" -setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0" +setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" -transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0" +transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0" typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" -typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0" +typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0" urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0" diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 07765e88..d224a838 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -135,8 +135,7 @@ class FlashLlamaAttention(torch.nn.Module): hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -158,17 +157,15 @@ class FlashLlamaAttention(torch.nn.Module): attn_output = torch.empty_like(qkv[:, 0]) # Prefill - if start_seq_prefill is not None: + if cu_seqlen_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq_prefill, - end_seq_prefill, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, + cu_seqlen_prefill, max_s, max_s, 0.0, @@ -261,8 +258,7 @@ class FlashLlamaLayer(nn.Module): residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -276,8 +272,7 @@ class FlashLlamaLayer(nn.Module): normed_hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -329,8 +324,7 @@ class FlashLlamaModel(torch.nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -352,8 +346,7 @@ class FlashLlamaModel(torch.nn.Module): residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache[i], block_tables, slots, @@ -381,8 +374,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -393,8 +385,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): hidden_states = self.model( input_ids, position_ids, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 9049878a..23c5e4ff 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -123,8 +123,7 @@ class FlashNeoxAttention(torch.nn.Module): hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -146,17 +145,15 @@ class FlashNeoxAttention(torch.nn.Module): attn_output = torch.empty_like(qkv[:, 0]) # Prefill - if start_seq_prefill is not None: + if cu_seqlen_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq_prefill, - end_seq_prefill, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, + cu_seqlen_prefill, max_s, max_s, 0.0, @@ -246,8 +243,7 @@ class FlashNeoXLayer(nn.Module): residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -261,8 +257,7 @@ class FlashNeoXLayer(nn.Module): ln1_hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -286,8 +281,7 @@ class FlashNeoXLayer(nn.Module): hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, 
kv_cache, block_tables, slots, @@ -341,8 +335,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -364,8 +357,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache[i], block_tables, slots, @@ -391,8 +383,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -403,8 +394,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): hidden_states = self.gpt_neox( input_ids, position_ids, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 44aa7cb1..cd42bfc2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -144,8 +144,7 @@ class FlashRWAttention(torch.nn.Module): hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -176,7 +175,7 @@ class FlashRWAttention(torch.nn.Module): attn_output = torch.empty_like(query) # Prefill - if start_seq_prefill is not None: + if cu_seqlen_prefill is not None: if self.num_heads_kv == 1: # Expand to query shape kv = kv.expand(-1, 2, self.num_heads, self.head_size) @@ -187,10 +186,8 @@ class FlashRWAttention(torch.nn.Module): torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, - start_seq_prefill, - end_seq_prefill, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, + cu_seqlen_prefill, max_s, max_s, 0.0, @@ -276,8 +273,7 @@ class FlashRWLargeAttention(torch.nn.Module): hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -311,7 +307,7 @@ class FlashRWLargeAttention(torch.nn.Module): attn_output = torch.empty_like(query) # Prefill - if start_seq_prefill is not None: + if cu_seqlen_prefill is not None: # Expand to query shape kv = ( kv.unsqueeze(2) @@ -325,10 +321,8 @@ class FlashRWLargeAttention(torch.nn.Module): torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, - start_seq_prefill, - end_seq_prefill, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, + cu_seqlen_prefill, max_s, max_s, 0.0, @@ -428,8 +422,7 @@ class FlashRWLayer(nn.Module): residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -443,8 +436,7 @@ class FlashRWLayer(nn.Module): ln_hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -466,8 +458,7 @@ class FlashRWLayer(nn.Module): hidden_states, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -516,8 +507,7 @@ class FlashRWLargeLayer(nn.Module): 
residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -532,8 +522,7 @@ class FlashRWLargeLayer(nn.Module): ln_attn, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -597,8 +586,7 @@ class FlashRWModel(FlashRWPreTrainedModel): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -620,8 +608,7 @@ class FlashRWModel(FlashRWPreTrainedModel): residual, cos, sin, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache[i], block_tables, slots, @@ -648,8 +635,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -660,8 +646,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): hidden_states = self.transformer( input_ids, position_ids, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 04eedef7..855a6e11 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -232,8 +232,7 @@ class FlashMQAttention(torch.nn.Module): def forward( self, hidden_states, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -259,7 +258,7 @@ class FlashMQAttention(torch.nn.Module): attn_output = torch.empty_like(query) # Prefill - if start_seq_prefill is not None: + if cu_seqlen_prefill is not None: # Expand from 1 to num_heads key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) @@ -269,10 +268,8 @@ class FlashMQAttention(torch.nn.Module): torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), attn_output, - start_seq_prefill, - end_seq_prefill, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, + cu_seqlen_prefill, max_s, max_s, 0.0, @@ -357,8 +354,7 @@ class Block(nn.Module): self, hidden_states, residual, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -369,8 +365,7 @@ class Block(nn.Module): hidden_states = self.attn( hidden_states, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, @@ -423,8 +418,7 @@ class FlashSantacoderModel(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -441,8 +435,7 @@ class FlashSantacoderModel(nn.Module): hidden_states, residual = layer( hidden_states, residual, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache[i], block_tables, slots, @@ -467,8 +460,7 @@ class FlashSantacoderForCausalLM(nn.Module): self, input_ids: torch.Tensor, 
position_ids: torch.Tensor, - start_seq_prefill: Optional[torch.Tensor], - end_seq_prefill: Optional[torch.Tensor], + cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, @@ -479,8 +471,7 @@ class FlashSantacoderForCausalLM(nn.Module): hidden_states = self.transformer( input_ids, position_ids, - start_seq_prefill, - end_seq_prefill, + cu_seqlen_prefill, kv_cache, block_tables, slots, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 94b14f85..bf5f5bbe 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -121,10 +121,10 @@ class FlashCausalLMBatch(Batch): input_ids: torch.Tensor position_ids: torch.Tensor - # tensor of length b holding starting offset of each sequence, only used in prefill - start_seq_prefill: Optional[torch.Tensor] - # tensor of length b holding ending offset of each sequence, only used in prefill - end_seq_prefill: Optional[torch.Tensor] + # Flash Attention values + + # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] # Paged Attention values @@ -197,8 +197,7 @@ class FlashCausalLMBatch(Batch): )["input_ids"] position_ids = [] - start_seq_prefill = [] - end_seq_prefill = [] + cu_seqlen_prefill = [0] needed_blocks_slots = [] start_slots = [] slot_indices = [] @@ -250,8 +249,7 @@ class FlashCausalLMBatch(Batch): position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs - start_seq_prefill.append(cumulative_length) - end_seq_prefill.append(cumulative_length + input_length) + cu_seqlen_prefill.append(cumulative_length + input_length) next_token_chooser_parameters.append(r.parameters) @@ -329,11 +327,8 @@ class FlashCausalLMBatch(Batch): position_ids = position_ids[0] slot_indices = slot_indices[0] - start_seq_prefill = torch.tensor( - start_seq_prefill, device=device, dtype=torch.int32 - ) - end_seq_prefill = torch.tensor( - end_seq_prefill, device=device, dtype=torch.int32 + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 ) position_ids = position_ids.to(device) @@ -345,9 +340,9 @@ class FlashCausalLMBatch(Batch): if all_prefill_logprobs: prefill_head_indices = None - prefill_next_token_indices = end_seq_prefill - 1 + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 elif no_prefill_logprobs: - prefill_head_indices = end_seq_prefill - 1 + prefill_head_indices = cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.tensor( @@ -363,8 +358,7 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - start_seq_prefill=start_seq_prefill, - end_seq_prefill=end_seq_prefill, + cu_seqlen_prefill=cu_seqlen_prefill, start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=needed_blocks_slots, @@ -504,8 +498,7 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - start_seq_prefill=None, - end_seq_prefill=None, + cu_seqlen_prefill=None, start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=None, @@ -652,8 +645,7 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - 
start_seq_prefill=None,
-            end_seq_prefill=None,
+            cu_seqlen_prefill=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=None,
@@ -750,8 +742,7 @@ class FlashCausalLM(Model):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        start_seq_prefill: Optional[torch.Tensor],
-        end_seq_prefill: Optional[torch.Tensor],
+        cu_seqlen_prefill: Optional[torch.Tensor],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
@@ -764,8 +755,7 @@ class FlashCausalLM(Model):
         return self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            start_seq_prefill=start_seq_prefill,
-            end_seq_prefill=end_seq_prefill,
+            cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=CACHE_MANAGER.kv_cache,
             block_tables=block_tables,
             slots=slots,
@@ -778,7 +768,7 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        prefill = batch.start_seq_prefill is not None
+        prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

         if batch.needed_blocks_slots:
@@ -788,8 +778,7 @@ class FlashCausalLM(Model):
         out = self.forward(
             batch.input_ids,
             batch.position_ids,
-            batch.start_seq_prefill,
-            batch.end_seq_prefill,
+            batch.cu_seqlen_prefill,
             batch.block_tables_tensor,
             batch.slots[batch.slot_indices],
             batch.input_lengths_tensor,
@@ -815,10 +804,9 @@ class FlashCausalLM(Model):
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
             next_position_ids = batch.position_ids.new_empty(len(batch))

-            batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
-            # We do not need start_seq_prefill and end_seq_prefill anymore
-            batch.start_seq_prefill = None
-            batch.end_seq_prefill = None
+            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
+            # We do not need cu_seqlen_prefill anymore
+            batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 3c0f8167..a4fe5105 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -66,7 +66,9 @@ class MPTSharded(CausalLM):
         if local_path.exists():
             filename = str(local_path.resolve())
         else:
-            filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+            filename = hf_hub_download(
+                model_id, revision=revision, filename="config.json"
+            )
         with open(filename, "r") as f:
             config = json.load(f)
         config = PretrainedConfig(**config)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index cbdfea66..8e0362b8 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -359,7 +359,7 @@ try:
         def __init__(self, inv_freq):
             super().__init__()

-            self.register_buffer("inv_freq", inv_freq)
+            self.inv_freq = inv_freq
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
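Note on the `cu_seqlen_prefill` change threaded through every model above: the two prefill offset tensors (`start_seq_prefill`, `end_seq_prefill`) are replaced by a single cumulative-length tensor, which is the shape `flash_attn_cuda.fwd` consumes and why it is now passed twice (once for the queries, once for the keys/values). The sketch below is illustrative only and not part of the patch; it assumes PyTorch and made-up prompt lengths.

```python
import torch

# Hypothetical prompt lengths for a batch of b = 3 requests (illustration only).
input_lengths = torch.tensor([5, 2, 7], dtype=torch.int32)

# cu_seqlen_prefill has length b + 1 and starts at 0, mirroring how the batch
# builder appends `cumulative_length + input_length` for every request.
cu_seqlen_prefill = torch.zeros(input_lengths.numel() + 1, dtype=torch.int32)
cu_seqlen_prefill[1:] = torch.cumsum(input_lengths, dim=0)
print(cu_seqlen_prefill)  # -> [0, 5, 7, 14]

# The old tensors were just two overlapping views of this one:
start_seq_prefill = cu_seqlen_prefill[:-1]  # -> [0, 5, 7]
end_seq_prefill = cu_seqlen_prefill[1:]     # -> [5, 7, 14]

# ...which is why the diff indexes the last prefill token of each sequence as
# `cu_seqlen_prefill[1:] - 1` (prefill_next_token_indices, slot_indices gather).
last_token_indices = cu_seqlen_prefill[1:] - 1  # -> [4, 6, 13]
```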
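The `block_tables`, `slots` and `slot_indices` fields added to `FlashCausalLMBatch` are the paged-attention side of the change: the KV cache is carved into fixed-size blocks and each sequence keeps a table of the blocks it owns. The toy sketch below only illustrates that indexing scheme; `BLOCK_SIZE`, the block ids and the flat slot layout are assumptions for the example, not the `CACHE_MANAGER` implementation from this patch.

```python
import torch

BLOCK_SIZE = 16  # tokens per KV-cache block (an assumption for this sketch)

# Pretend the cache manager handed this sequence three non-contiguous blocks.
block_table = torch.tensor([7, 2, 11])

def slot_for(pos: int) -> int:
    """Flat cache slot holding the key/value of token `pos` of this sequence."""
    block = int(block_table[pos // BLOCK_SIZE])
    return block * BLOCK_SIZE + pos % BLOCK_SIZE

print(slot_for(0))   # 112 -> position 0 lands in block 7, offset 0
print(slot_for(20))  # 36  -> position 20 lands in block 2, offset 4
```

Because the mapping is a pure function of the block table, the model forward pass only needs precomputed flat slots, which is consistent with the diff passing `batch.slots[batch.slot_indices]` into `forward` instead of materializing a contiguous cache per sequence.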