Fix clippy warnings and code formatting.

This commit is contained in:
Nicolas Patry 2024-08-09 14:54:52 +02:00
parent 379e1659a9
commit a4b1806557
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
6 changed files with 33 additions and 20 deletions

View File

@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token, Attention}; use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify}; use tokio::sync::{mpsc, Notify};
use tokio::time::Instant; use tokio::time::Instant;
@ -36,11 +36,17 @@ impl BackendV3 {
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention.parse().unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else { } else {
Attention::Paged Attention::Paged
}; };
let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 }; let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
};
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,

View File

@ -1,11 +1,10 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
Attention,
}; };
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token, Attention}; use crate::{Attention, FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
@ -42,11 +41,17 @@ impl BackendV2 {
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention.parse().expect(&format!("Invalid attention was specified :`{attention}`")) attention
.parse()
.expect(&format!("Invalid attention was specified :`{attention}`"))
} else { } else {
Attention::Paged Attention::Paged
}; };
let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 }; let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate); let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());

View File

@ -16,7 +16,7 @@ use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
#[derive(PartialEq)] #[derive(PartialEq)]
pub enum Attention{ pub enum Attention {
Paged, Paged,
FlashDecoding, FlashDecoding,
FlashInfer, FlashInfer,
@ -25,21 +25,21 @@ pub enum Attention{
#[derive(Debug)] #[derive(Debug)]
pub struct ParseError; pub struct ParseError;
impl std::fmt::Display for ParseError{ impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value") write!(f, "Cannot parse attention value")
} }
} }
impl std::error::Error for ParseError{} impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention{ impl std::str::FromStr for Attention {
type Err = ParseError; type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err>{ fn from_str(s: &str) -> Result<Self, Self::Err> {
match s{ match s {
"paged" => Ok(Attention::Paged), "paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding), "flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer), "flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError) _ => Err(ParseError),
} }
} }
} }

View File

@ -6,8 +6,10 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
ATTENTION = os.getenv("ATTENTION", "paged") ATTENTION = os.getenv("ATTENTION", "paged")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ATTENTION in _expected, f"Attention is not valid {ATTENTION}, expected {_expected}" assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}") log_master(logger.info, f"Using Attention = {ATTENTION}")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None