From 6bcad66c6e6f1af7f26f5a2c8314f139c7038bdb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 9 Aug 2024 12:31:08 +0200 Subject: [PATCH] Using an enum for flash backends (paged/flashdecoding/flashinfer) --- backends/v3/src/backend.rs | 10 +++---- launcher/src/main.rs | 2 +- router/src/infer/v2/scheduler.rs | 11 +++---- router/src/lib.rs | 29 +++++++++++++++++++ .../layers/attention/common.py | 4 +-- .../layers/attention/cuda.py | 11 ++++--- .../layers/attention/rocm.py | 4 +-- .../models/flash_causal_lm.py | 11 ++++--- .../text_generation_server/models/globals.py | 10 ++----- 9 files changed, 58 insertions(+), 34 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 6b3e0526..5c6c2a70 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{FinishReason, PrefillToken, Token}; +use text_generation_router::{FinishReason, PrefillToken, Token, Attention}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -35,12 +35,12 @@ impl BackendV3 { window_size: Option<u32>, speculate: u32, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention.parse().expect(&format!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged }; - let block_size = if flashdecoding { 256 } else { 16 }; + let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8acfda0c..a64b1d71 100644 --- a/launcher/src/main.rs +++ 
b/launcher/src/main.rs @@ -1461,7 +1461,7 @@ fn main() -> Result<(), LauncherError> { if config.model_type == Some("gemma2".to_string()) { tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("FLASH_DECODING", "1"); + std::env::set_var("ATTENTION", "flashdecoding"); } let config: Config = config.into(); diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index cc333674..7a93338b 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -2,9 +2,9 @@ use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; +use crate::{FinishReason, PrefillToken, Token, Attention}; use nohash_hasher::IntMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -40,12 +40,12 @@ impl BackendV2 { generation_health: Arc<AtomicBool>, ) -> Self { // Infer shared state - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention.parse().expect(&format!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged }; - let block_size = if flashdecoding { 256 } else { 16 }; + let block_size = if attention == Attention::FlashDecoding { 256 } else { 16 }; let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); diff --git a/router/src/lib.rs b/router/src/lib.rs index a956b058..d1c3b25e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -15,6 +15,35 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[derive(PartialEq)] +pub enum Attention{ + Paged, + FlashDecoding, + FlashInfer, +} + +#[derive(Debug)] +pub struct ParseError; 
+impl std::fmt::Display for ParseError{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cannot parse attention value") + } +} +impl std::error::Error for ParseError{} + +impl std::str::FromStr for Attention{ + type Err = ParseError; + fn from_str(s: &str) -> Result<Self, Self::Err>{ + match s{ + "paged" => Ok(Attention::Paged), + "flashdecoding" => Ok(Attention::FlashDecoding), + "flashinfer" => Ok(Attention::FlashInfer), + _ => Err(ParseError) + } + } +} + #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")] diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index b986a082..f162230c 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER +from text_generation_server.models.globals import ATTENTION import torch from typing import Optional -if FLASH_DECODING or FLASH_INFER: +if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 1b8e9209..d039e1e7 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,9 +1,8 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( - FLASH_DECODING, + ATTENTION, BLOCK_SIZE, - FLASH_INFER, ) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -27,7 +26,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: shape = 
key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -76,7 +75,7 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import decode_state return decode_state.get().forward( @@ -85,7 +84,7 @@ def paged_attention( logits_soft_cap=softcap, sm_scale=softmax_scale, ) - elif FLASH_DECODING: + elif ATTENTION == "flashdecoding": max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -219,7 +218,7 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if FLASH_INFER: +if ATTENTION == "flashinfer": def attention( q, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 69e64162..16ce8d2b 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,7 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import ATTENTION from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master from loguru import logger @@ -28,7 +28,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if ATTENTION == "flashdecoding": shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 12aa7dcd..21b66a68 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ 
b/server/text_generation_server/models/flash_causal_lm.py @@ -40,8 +40,7 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, - FLASH_DECODING, - FLASH_INFER, + ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -938,7 +937,7 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_prefill_state, create_decode_state, @@ -990,7 +989,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: self.kv_cache = [ ( torch.empty( @@ -1062,7 +1061,7 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_decode_state_cuda_graphs, ) @@ -1766,7 +1765,7 @@ class FlashCausalLM(Model): input_lengths: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: - if not FLASH_INFER: + if ATTENTION != "flashinfer": return nullcontext() from text_generation_server.layers.attention.flash_infer import ( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 42b43c87..7e22a56c 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,12 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} -if FLASH_INFER: - log_master(logger.info, "Using FLASH_INFER") +ATTENTION = os.getenv("ATTENTION", "paged") +log_master(logger.info, f"Using Attention = {ATTENTION}") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # 
This is overridden by the cli -FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} -BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 -if FLASH_DECODING: - log_master(logger.info, "Using FLASH_DECODING") +BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 cuda_graphs = os.getenv("CUDA_GRAPHS")