mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
Fix clippy and fmt.
commit a4b1806557
parent 379e1659a9
@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{FinishReason, PrefillToken, Token, Attention};
+use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -36,11 +36,17 @@ impl BackendV3 {
         speculate: u32,
     ) -> Self {
         let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention.parse().unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
+            attention
+                .parse()
+                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
         } else {
             Attention::Paged
         };
-        let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 };
+        let block_size = if attention == Attention::FlashDecoding {
+            256
+        } else {
+            16
+        };

         let queue = Queue::new(
             requires_padding,
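This hunk is a pure rustfmt reflow: the parse-or-panic on ATTENTION and the 256-vs-16 block-size choice are unchanged. A minimal sketch of the same selection, simplified to a plain string comparison and without the invalid-value panic (so not the backend's actual types):

    // Sketch only: default to "paged" and pick the larger KV-cache block size
    // just for flashdecoding, mirroring the values shown in the diff.
    fn main() {
        let attention = std::env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
        let block_size = if attention == "flashdecoding" { 256 } else { 16 };
        println!("ATTENTION={attention} -> block_size={block_size}");
    }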
@@ -1,11 +1,10 @@
 /// Batching and inference logic
 use crate::infer::v2::queue::{Entry, Queue};
 use crate::infer::{
-    Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
-    Attention,
+    Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
 };
 use crate::validation::ValidGenerateRequest;
-use crate::{FinishReason, PrefillToken, Token, Attention};
+use crate::{Attention, FinishReason, PrefillToken, Token};
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -42,11 +41,17 @@ impl BackendV2 {
     ) -> Self {
         // Infer shared state
         let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention.parse().expect(&format!("Invalid attention was specified :`{attention}`"))
+            attention
+                .parse()
+                .expect(&format!("Invalid attention was specified :`{attention}`"))
         } else {
             Attention::Paged
         };
-        let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 };
+        let block_size = if attention == Attention::FlashDecoding {
+            256
+        } else {
+            16
+        };
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());

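Note the small divergence between the two backends this commit touches: the v3 hunk above switches to `.unwrap_or_else(|_| panic!(..))`, while this v2 hunk keeps `.expect(&format!(..))` and only reflows it. Clippy's `expect_fun_call` lint normally nudges toward the former, since the closure defers building the error message to the failure path. A small, hedged sketch of that pattern (the function name and `main` harness are illustrative, not code from the repository):

    // With `expect(&format!(..))` the message string is built even when parsing
    // succeeds; `unwrap_or_else(|_| panic!(..))` builds it only on failure.
    fn parse_block_size(raw: &str) -> u32 {
        raw.parse()
            .unwrap_or_else(|_| panic!("Invalid block size was specified :`{raw}`"))
    }

    fn main() {
        assert_eq!(parse_block_size("256"), 256);
    }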
@@ -39,7 +39,7 @@ impl std::str::FromStr for Attention{
             "paged" => Ok(Attention::Paged),
             "flashdecoding" => Ok(Attention::FlashDecoding),
             "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError)
+            _ => Err(ParseError),
         }
     }
 }
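For context, the match arms above sit inside the `FromStr` impl for `Attention` that the backends call via `.parse()`. A rough, self-contained reconstruction is sketched below; only the three string arms and the `ParseError` name come from this diff, while the derives, the unit-struct error, and the `main` harness are assumptions for illustration:

    // Hedged reconstruction for context, not a verbatim copy of the crate.
    #[derive(Debug, PartialEq)]
    enum Attention {
        Paged,
        FlashDecoding,
        FlashInfer,
    }

    #[derive(Debug)]
    struct ParseError;

    impl std::str::FromStr for Attention {
        type Err = ParseError;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "paged" => Ok(Attention::Paged),
                "flashdecoding" => Ok(Attention::FlashDecoding),
                "flashinfer" => Ok(Attention::FlashInfer),
                _ => Err(ParseError),
            }
        }
    }

    fn main() {
        // Same call shape as the backend hunks: parse the ATTENTION string into the enum.
        assert_eq!("flashinfer".parse::<Attention>().unwrap(), Attention::FlashInfer);
        assert!("unknown".parse::<Attention>().is_err());
    }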
@@ -7,7 +7,9 @@ from text_generation_server.utils.log import log_master

 ATTENTION = os.getenv("ATTENTION", "paged")
 _expected = {"paged", "flashdecoding", "flashinfer"}
-assert ATTENTION in _expected, f"Attention is not valid {ATTENTION}, expected {_expected}"
+assert (
+    ATTENTION in _expected
+), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None