diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index e744e0d0..68ddf00b 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{FinishReason, PrefillToken, Token, Attention};
+use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -36,11 +36,17 @@ impl BackendV3 {
         speculate: u32,
     ) -> Self {
         let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention.parse().unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
+            attention
+                .parse()
+                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
         } else {
             Attention::Paged
         };
-        let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 };
+        let block_size = if attention == Attention::FlashDecoding {
+            256
+        } else {
+            16
+        };
 
         let queue = Queue::new(
             requires_padding,
diff --git a/docs/openapi.json b/docs/openapi.json
index 9d281a48..ed9b0b96 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -2080,4 +2080,4 @@
       "description": "Hugging Face Text Generation Inference API"
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md
index 7507687f..2b9d7f99 100644
--- a/docs/source/conceptual/quantization.md
+++ b/docs/source/conceptual/quantization.md
@@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants:
 2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
 3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)
 
-For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
+For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
 
 ## Quantization with bitsandbytes
 
@@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num-
 You can learn more about the quantization options by running `text-generation-server quantize --help`.
 
 If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
-You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
\ No newline at end of file
+You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs
index 7a93338b..0e5fc8a3 100644
--- a/router/src/infer/v2/scheduler.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -1,11 +1,10 @@
 /// Batching and inference logic
 use crate::infer::v2::queue::{Entry, Queue};
 use crate::infer::{
-    Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
-    Attention,
+    Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
 };
 use crate::validation::ValidGenerateRequest;
-use crate::{FinishReason, PrefillToken, Token, Attention};
+use crate::{Attention, FinishReason, PrefillToken, Token};
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -42,11 +41,17 @@ impl BackendV2 {
     ) -> Self {
         // Infer shared state
         let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention.parse().expect(&format!("Invalid attention was specified :`{attention}`"))
+            attention
+                .parse()
+                .expect(&format!("Invalid attention was specified :`{attention}`"))
         } else {
             Attention::Paged
         };
-        let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 };
+        let block_size = if attention == Attention::FlashDecoding {
+            256
+        } else {
+            16
+        };
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
 
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d1c3b25e..66738706 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -16,7 +16,7 @@ use utoipa::ToSchema;
 use validation::Validation;
 
 #[derive(PartialEq)]
-pub enum Attention{
+pub enum Attention {
     Paged,
     FlashDecoding,
     FlashInfer,
@@ -25,21 +25,21 @@ pub enum Attention{
 #[derive(Debug)]
 pub struct ParseError;
 
-impl std::fmt::Display for ParseError{
+impl std::fmt::Display for ParseError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "Cannot parse attention value")
     }
 }
-impl std::error::Error for ParseError{}
+impl std::error::Error for ParseError {}
 
-impl std::str::FromStr for Attention{
+impl std::str::FromStr for Attention {
     type Err = ParseError;
-    fn from_str(s: &str) -> Result<Self, Self::Err>{
-        match s{
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
             "paged" => Ok(Attention::Paged),
             "flashdecoding" => Ok(Attention::FlashDecoding),
             "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError)
+            _ => Err(ParseError),
         }
     }
 }
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index ecd1ab32..b58a5b80 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -6,8 +6,10 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master
 
 ATTENTION = os.getenv("ATTENTION", "paged")
-_expected = {"paged", "flashdecoding", "flashinfer"}
-assert ATTENTION in _expected, f"Attention is not valid {ATTENTION}, expected {_expected}"
+_expected = {"paged", "flashdecoding", "flashinfer"}
+assert (
+    ATTENTION in _expected
+), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 
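For reference, a minimal self-contained sketch of the behaviour the reformatted hunks carry: parsing the ATTENTION environment variable into the `Attention` enum from router/src/lib.rs and picking the KV-cache block size (256 for flashdecoding, 16 otherwise). This is not part of the patch; `Debug` is derived here only so the result can be printed, everything else mirrors the patched code.

// sketch_attention.rs — standalone reproduction of the ATTENTION selection logic
#[derive(Debug, PartialEq)]
pub enum Attention {
    Paged,
    FlashDecoding,
    FlashInfer,
}

#[derive(Debug)]
pub struct ParseError;

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Cannot parse attention value")
    }
}
impl std::error::Error for ParseError {}

impl std::str::FromStr for Attention {
    type Err = ParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "paged" => Ok(Attention::Paged),
            "flashdecoding" => Ok(Attention::FlashDecoding),
            "flashinfer" => Ok(Attention::FlashInfer),
            _ => Err(ParseError),
        }
    }
}

fn main() {
    // Same fallback as the backends: an unset variable means paged attention.
    let attention = if let Ok(attention) = std::env::var("ATTENTION") {
        attention
            .parse()
            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
    } else {
        Attention::Paged
    };

    // Flash-decoding uses 256-token KV-cache blocks; paged/flashinfer use 16.
    let block_size = if attention == Attention::FlashDecoding {
        256
    } else {
        16
    };

    println!("attention = {attention:?}, block_size = {block_size}");
}

Running this with ATTENTION=flashdecoding prints block_size = 256; leaving the variable unset falls back to paged attention and block_size = 16, and an unrecognised value panics with the same message the router uses.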