diff --git a/router/client/src/pb/generate.v2.rs b/router/client/src/pb/generate.v2.rs new file mode 100644 index 00000000..1a206360 --- /dev/null +++ b/router/client/src/pb/generate.v2.rs @@ -0,0 +1,647 @@ +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthResponse {} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoResponse { + #[prost(bool, tag = "1")] + pub requires_padding: bool, + #[prost(string, tag = "2")] + pub dtype: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub device_type: ::prost::alloc::string::String, + #[prost(uint32, optional, tag = "4")] + pub window_size: ::core::option::Option, + #[prost(uint32, tag = "5")] + pub speculate: u32, +} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryResponse { + /// / Other shards urls + #[prost(string, repeated, tag = "1")] + pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheRequest { + /// / Optional batch id + #[prost(uint64, optional, tag = "1")] + pub id: ::core::option::Option, +} +/// / Empty response +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheResponse {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NextTokenChooserParameters { + /// / exponential scaling output probability distribution + #[prost(float, tag = "1")] + pub temperature: f32, + /// / restricting to the k highest probability elements + #[prost(uint32, tag = "2")] + pub top_k: u32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "3")] + pub top_p: f32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "4")] + pub typical_p: f32, + /// / apply sampling on the logits + #[prost(bool, tag = "5")] + pub do_sample: bool, + /// / random seed for sampling + #[prost(uint64, tag = "6")] + pub seed: u64, + /// / repetition penalty + #[prost(float, tag = "7")] + pub repetition_penalty: f32, + /// / frequency penalty + #[prost(float, tag = "9")] + pub frequency_penalty: f32, + /// / token watermarking using "A Watermark for Large Language Models" + #[prost(bool, tag = "8")] + pub watermark: bool, + /// / grammar (applied if not empty) + #[prost(string, tag = "10")] + pub grammar: ::prost::alloc::string::String, + /// / grammar type + #[prost(enumeration = "GrammarType", tag = "11")] + pub grammar_type: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StoppingCriteriaParameters { + /// / Maximum number of generated tokens + #[prost(uint32, tag = "1")] + pub max_new_tokens: u32, + /// / Optional stopping sequences + #[prost(string, repeated, tag = "2")] + pub stop_sequences: 
::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / Ignore end of sequence token + /// / used for benchmarking + #[prost(bool, tag = "3")] + pub ignore_eos_token: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Request { + /// / Request ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / The generation context + #[prost(string, tag = "2")] + pub inputs: ::prost::alloc::string::String, + /// / Context truncation + #[prost(uint32, tag = "3")] + pub truncate: u32, + /// / Next Token Chooser Parameters + #[prost(message, optional, tag = "4")] + pub parameters: ::core::option::Option, + /// / Stopping Criteria Parameters + #[prost(message, optional, tag = "5")] + pub stopping_parameters: ::core::option::Option, + /// / Return prefill logprobs + #[prost(bool, tag = "6")] + pub prefill_logprobs: bool, + /// / Return most likely n tokens + #[prost(uint32, tag = "7")] + pub top_n_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Batch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests + #[prost(message, repeated, tag = "2")] + pub requests: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CachedBatch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests ids + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GeneratedText { + /// / Output + #[prost(string, tag = "1")] + pub text: ::prost::alloc::string::String, + /// / Number of generated tokens + #[prost(uint32, tag = "2")] + pub generated_tokens: u32, + /// / Finish reason + #[prost(enumeration = "FinishReason", tag = "3")] + pub finish_reason: i32, + /// / Seed + #[prost(uint64, optional, tag = "4")] + pub seed: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Tokens { + /// / Token IDs + #[prost(uint32, repeated, tag = "1")] + pub ids: ::prost::alloc::vec::Vec, + /// / Logprobs + #[prost(float, repeated, tag = "2")] + pub logprobs: ::prost::alloc::vec::Vec, + /// / tokens + #[prost(string, repeated, tag = "3")] + pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / special + #[prost(bool, repeated, tag = "4")] + pub is_special: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Generation { + /// / Request ID + #[prost(uint64, tag = "1")] + pub request_id: u64, + /// / Prefill tokens (optional) + #[prost(message, optional, tag = "2")] + pub prefill_tokens: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub tokens: ::core::option::Option, + /// / Complete generated text + #[prost(message, optional, tag = "4")] + pub generated_text: ::core::option::Option, + /// / Top tokens + #[prost(message, 
repeated, tag = "5")] + pub top_tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchRequest { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub batch_id: u64, + /// / Requests to keep + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchResponse { + /// / Filtered Batch (cached) + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillRequest { + /// / Batch + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillResponse { + /// / Generation + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeRequest { + /// / Cached batches + #[prost(message, repeated, tag = "1")] + pub batches: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeResponse { + /// / Decodes + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, + /// / Concatenate elapsed time in nanoseconds + #[prost(uint64, optional, tag = "6")] + pub concat_ns: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupRequest { + /// / Batch to warmup on + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, + #[prost(uint32, tag = "2")] + pub max_input_length: u32, + #[prost(uint32, tag = "3")] + pub max_prefill_tokens: u32, + #[prost(uint32, tag = "4")] + pub max_total_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupResponse { + /// / Maximum number of tokens supported by the model + #[prost(uint32, optional, tag = "1")] + pub max_supported_total_tokens: ::core::option::Option, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum GrammarType { + None = 0, + Json = 1, + Regex = 2, +} +impl GrammarType { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + GrammarType::None => "GRAMMAR_TYPE_NONE", + GrammarType::Json => "GRAMMAR_TYPE_JSON", + GrammarType::Regex => "GRAMMAR_TYPE_REGEX", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GRAMMAR_TYPE_NONE" => Some(Self::None), + "GRAMMAR_TYPE_JSON" => Some(Self::Json), + "GRAMMAR_TYPE_REGEX" => Some(Self::Regex), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum FinishReason { + Length = 0, + EosToken = 1, + StopSequence = 2, +} +impl FinishReason { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + FinishReason::Length => "FINISH_REASON_LENGTH", + FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN", + FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FINISH_REASON_LENGTH" => Some(Self::Length), + "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken), + "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence), + _ => None, + } + } +} +/// Generated client implementations. +pub mod text_generation_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct TextGenerationServiceClient { + inner: tonic::client::Grpc, + } + impl TextGenerationServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl TextGenerationServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> TextGenerationServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. 
+ #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// / Model Info + pub async fn info( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Info", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info")); + self.inner.unary(req, path, codec).await + } + /// / Service discovery + pub async fn service_discovery( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/ServiceDiscovery", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "generate.v2.TextGenerationService", + "ServiceDiscovery", + ), + ); + self.inner.unary(req, path, codec).await + } + /// / Empties batch cache + pub async fn clear_cache( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/ClearCache", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"), + ); + self.inner.unary(req, path, codec).await + } + /// / Remove requests from a cached batch + pub async fn filter_batch( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/FilterBatch", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"), + ); + self.inner.unary(req, path, codec).await + } + /// / Warmup the model and compute max cache size + pub async fn warmup( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + 
.await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Warmup", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup")); + self.inner.unary(req, path, codec).await + } + /// / Prefill batch and decode first token + pub async fn prefill( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Prefill", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill")); + self.inner.unary(req, path, codec).await + } + /// / Decode token for a list of prefilled batches + pub async fn decode( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Decode", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode")); + self.inner.unary(req, path, codec).await + } + /// / Health check + pub async fn health( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/Health", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health")); + self.inner.unary(req, path, codec).await + } + } +} diff --git a/router/client/src/pb/mod.rs b/router/client/src/pb/mod.rs new file mode 100644 index 00000000..095ead1f --- /dev/null +++ b/router/client/src/pb/mod.rs @@ -0,0 +1,6 @@ +// This file is @generated by prost-build. 
+pub mod generate {
+    pub mod v2 {
+        include!("generate.v2.rs");
+    }
+}
diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs
index ba6f520d..926de0fa 100644
--- a/router/src/infer/v2/scheduler.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -39,8 +39,18 @@ impl SchedulerV2 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
-        let batching_task_notifier = Arc::new(Notify::new());
+        // Infer shared state
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 16 };
+        let block_size = std::env::var("BLOCK_SIZE")
+            .map(|b| b.parse().unwrap_or(block_size))
+            .unwrap_or(block_size);
+        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+        let batching_task_notifier = Arc::new(Notify::new());
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 583337bd..a564237c 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -1,5 +1,6 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -21,7 +22,14 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    if FLASH_DECODING:
+        shape = key_cache.shape
+        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+    else:
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
 
 
 def paged_attention(
@@ -32,7 +40,8 @@
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_k: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -56,64 +65,100 @@
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = cu_seqlen_k
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    if FLASH_DECODING:
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
 
-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
-    if use_v1:
-        ops.paged_attention_v1(
-            out,
+        # TODO fixme when flash contains the fix.
+ # Number of splits is not correctly handled + # by the current path + # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 + # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. + out2 = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, None, - "auto", - 1.0, + cu_seqlen_k, + cu_seqlen_k, + None, + block_tables, + None, + max_q, + max_k, + 0.0, # dropout + softmax_scale, + False, # zero_tensors + True, # causal + -1, # Window_left + -1, # Window right + False, # return softmax + None, # generator ) + return out2[0] else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) + from vllm._C import ops - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, + use_v1 = max_s <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 ) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) try: diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 7f086b68..db79c589 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -55,7 +55,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( @@ -66,7 +67,7 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - input_lengths, + cu_seqlen_q, BLOCK_SIZE, max_s, None, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 91ed5818..c5c485de 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,6 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -26,7 +27,14 
@@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -37,7 +45,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -61,6 +70,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + input_lengths = cu_seqlen_k # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py new file mode 100644 index 00000000..df6b1ade --- /dev/null +++ b/server/text_generation_server/models/cache_manager.py @@ -0,0 +1,158 @@ +import math +import torch + +from typing import Optional, List, Tuple +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING + +BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 +# Will be set in warmup +CACHE_MANAGER: Optional["CacheManager"] = None + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + repeat_slots: bool, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + self.num_blocks = num_blocks + self.repeat_slots = repeat_slots + + element_size = torch.tensor([], dtype=dtype).element_size() + if SYSTEM == "xpu": + x = 1 + else: + x = self.block_size // element_size + + if FLASH_DECODING: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, self.block_size, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, self.block_size, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int64 + ).view(num_blocks, self.block_size) + + def allocate( + self, + needed_blocks_slots: List[Tuple[int, int]], + blocks: int, + max_blocks: int, + device: torch.device, + ): + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + if blocks > len(free_block_indices): + raise RuntimeError( + f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" + ) + + # Slice by the number of required blocks + block_indices = free_block_indices[:blocks] + block_indices = block_indices.flatten() + + # Padded block tables 
+ block_tables_tensor = torch.zeros( + (len(needed_blocks_slots), max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks : cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + all_slots = self.slots[allocated_blocks].flatten() + + # Repeat slots in the case of context sliding window + if needed_slots > len(all_slots) and self.repeat_slots: + repeats = math.ceil(needed_slots / len(all_slots)) + all_slots = all_slots.repeat(repeats) + + allocated_slots = all_slots[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + block_tables = block_tables + block_tables_tensor = block_tables_tensor.to(device) + slots = torch.concat(slots).to(device) + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + return block_tables, block_tables_tensor, slots + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + + +def set_cache_manager( + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + repeat_slots: bool, + dtype: torch.dtype, + device: torch.device, +) -> CacheManager: + global CACHE_MANAGER + if CACHE_MANAGER is not None: + del CACHE_MANAGER + torch.cuda.empty_cache() + + CACHE_MANAGER = CacheManager( + num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device + ) + return CACHE_MANAGER + + +def get_cache_manager() -> CacheManager: + global CACHE_MANAGER + if CACHE_MANAGER is None: + raise RuntimeError("cache manager was not initialized") + + return CACHE_MANAGER diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 2850a6f3..7ccdf4a2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -259,8 +260,9 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ): qkv = self.query_key_value(hidden_states) @@ -312,7 +314,8 @@ class FlashCohereAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, ) @@ -386,8 +389,9 @@ class FlashCohereLayer(nn.Module): cu_seqlen_prefill, kv_cache, block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -400,8 +404,9 @@ class FlashCohereLayer(nn.Module): cu_seqlen_prefill, kv_cache, block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ) @@ -464,6 +469,24 @@ class FlashCohereModel(torch.nn.Module): ) residual = None + if 
cu_seqlen_prefill is None and FLASH_DECODING: + cu_seqlen_q = torch.arange( + input_lengths.shape[0] + 1, + device=input_ids.device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.cat( + [ + torch.zeros( + (1,), device=input_lengths.device, dtype=input_lengths.dtype + ), + input_lengths.cumsum(dim=-1), + ] + ).to(dtype=torch.int32) + else: + cu_seqlen_q = None + cu_seqlen_k = input_lengths + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -473,8 +496,9 @@ class FlashCohereModel(torch.nn.Module): cu_seqlen_prefill, kv_cache[i], block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9d56e4ef..74dc9cf7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -344,6 +344,7 @@ class DbrxAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 82891823..a885dad6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -253,6 +253,7 @@ class FlashGemmaAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 7e7510c7..ef238297 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -253,6 +253,7 @@ class FlashGPT2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6b82aeca..f33b1622 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -172,7 +173,8 @@ class FlashLlamaAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, adapter_data, ): @@ -192,10 +194,10 @@ class FlashLlamaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor - attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: + attn_output = torch.empty_like(query) # flash attention attention( query, @@ -208,15 +210,16 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention( - attn_output, + attn_output = paged_attention( + None, query, kv_cache[0], kv_cache[1], self.kv_head_mapping, self.softmax_scale, block_tables, 
- input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, ) @@ -353,7 +356,8 @@ class FlashLlamaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, adapter_data, ): @@ -368,7 +372,8 @@ class FlashLlamaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, adapter_data, ) @@ -438,6 +443,23 @@ class FlashLlamaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( position_ids, max_s, hidden_states.dtype ) + if cu_seqlen_prefill is None and FLASH_DECODING: + cu_seqlen_q = torch.arange( + input_lengths.shape[0] + 1, + device=inputs_embeds.device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.cat( + [ + torch.zeros( + (1,), device=input_lengths.device, dtype=input_lengths.dtype + ), + input_lengths.cumsum(dim=-1), + ] + ).to(dtype=torch.int32) + else: + cu_seqlen_q = None + cu_seqlen_k = input_lengths residual = None for i, layer in enumerate(self.layers): @@ -450,7 +472,8 @@ class FlashLlamaModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, adapter_data, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d1ba5564..673e501b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -237,6 +237,7 @@ class MistralAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2e839d15..2d7b023f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -299,6 +299,7 @@ class MixtralAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b87fd4ca..d4e7713d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -176,6 +176,7 @@ class FlashNeoxAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 3f445f97..5872b59f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -215,6 +215,7 @@ class FlashPhiAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 69f38c3a..14aee59b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -157,6 +157,7 @@ class Qwen2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 04d4ba51..735d0b90 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -225,6 +225,7 @@ class FlashRWAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) @@ -348,6 +349,7 @@ class FlashRWLargeAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index badfc367..3873c653 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -309,6 +309,7 @@ class FlashMQAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index f6a2e15d..4450deb8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -263,6 +263,7 @@ class Starcoder2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, + None, input_lengths, max_s, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a0a78b33..a3687d95 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -46,7 +46,7 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) -BLOCK_SIZE: int = 16 +BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16 # Will be set in init SLIDING_WINDOW: Optional[int] = None diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index bde86e6e..11693436 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -6,6 +6,9 @@ from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli cuda_graphs = os.getenv("CUDA_GRAPHS") +FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} +if FLASH_DECODING: + logger.info("Using FLASH_DECODING") if cuda_graphs is not None: try: cuda_graphs = [int(item) for item in cuda_graphs.split(",")] diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 011e0f63..2412e4f7 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -67,7 +67,7 @@ elif is_ipex_available(): synchronize = noop get_free_memory = 
get_cpu_free_memory
 else:
-    SYSTEM = "cpu"
+    SYSTEM = "ipex"
     empty_cache = noop
     synchronize = noop
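
Note: the sketch below is illustrative only (plain PyTorch, toy shapes and made-up slot/length values). It mirrors the two mechanisms this patch introduces: writing K/V into the flat (num_blocks, block_size, num_heads, head_size) cache layout through slot indices, and building cu_seqlen_q / cu_seqlen_k from input_lengths for the varlen flash-attention call. It does not call flash_attn_2_cuda itself.

```python
import torch

# Toy dimensions (made up for illustration only).
num_blocks, block_size, num_heads, head_size = 4, 256, 2, 8
batch_size = 3

# Flash-decoding cache layout: (num_blocks, block_size, num_heads, head_size),
# so a flat slot index addresses exactly one token position.
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros_like(key_cache)


def reshape_and_cache_flash(key, value, key_cache, value_cache, slots):
    # Same idea as the FLASH_DECODING branch of reshape_and_cache():
    # view the cache as (num_blocks * block_size, heads, dim) and scatter by slot.
    shape = key_cache.shape
    key_cache.view(-1, shape[-2], shape[-1])[slots] = key
    value_cache.view(-1, shape[-2], shape[-1])[slots] = value


# Decode step: one new token per sequence, each sequence owns one slot.
key = torch.randn(batch_size, num_heads, head_size)
value = torch.randn(batch_size, num_heads, head_size)
slots = torch.tensor([0, 256, 513])  # block 0 / pos 0, block 1 / pos 0, block 2 / pos 1
reshape_and_cache_flash(key, value, key_cache, value_cache, slots)

# Cumulative sequence lengths handed to the varlen flash-attention kernel:
# cu_seqlen_q counts one query token per sequence, cu_seqlen_k is the prefix
# sum of per-sequence KV lengths (mirrors the construction added to the
# modeling code above).
input_lengths = torch.tensor([5, 3, 9], dtype=torch.int64)
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)
cu_seqlen_k = torch.cat(
    [torch.zeros((1,), dtype=input_lengths.dtype), input_lengths.cumsum(dim=-1)]
).to(dtype=torch.int32)

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 5, 8, 17]
```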