From 9d0ca503a898d59b732436bb9eb0457c04efc30a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?icyboy=E2=84=A2?= Date: Mon, 1 Jul 2024 20:17:22 +0800 Subject: [PATCH 01/13] fix AttributeError: 'MixtralLayer' object has no attribute 'mlp' (#2123) https://github.com/huggingface/text-generation-inference/issues/2122 --- server/text_generation_server/models/flash_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 209eca83..0f5746de 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -153,7 +153,7 @@ class BaseFlashMistral(FlashCausalLM): # TODO: this is a hack to avoid the gate_proj for # FlashStarcoder2 that doesnt have these layers - if hasattr(layer.mlp, "gate_up_proj"): + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): layer_weights[(i, "gate_proj")] = ( f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj, From 5da4cfab1c211ff3e2aefbd0358f714970fb8360 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 1 Jul 2024 20:32:54 +0800 Subject: [PATCH 02/13] refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel platform (#2132) * refine get xpu free memory Signed-off-by: Wang, Yi A * enable qwen2 in xpu Signed-off-by: Wang, Yi A * enable gemma/gemma2/phi in intel platform Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A --- .../text_generation_server/layers/attention/ipex.py | 3 ++- server/text_generation_server/models/flash_gemma.py | 8 ++++++++ server/text_generation_server/models/flash_gemma2.py | 8 ++++++++ server/text_generation_server/models/flash_phi.py | 8 ++++++++ server/text_generation_server/models/flash_qwen2.py | 8 ++++++++ server/text_generation_server/utils/import_utils.py | 12 ++++++++---- 6 files changed, 42 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index bfab0119..7f086b68 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -14,6 +14,7 @@ def attention( max_s, softmax_scale, window_size_left=-1, + causal=True, ): # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( @@ -28,7 +29,7 @@ def attention( 0.0, softmax_scale, False, - True, + causal, False, None, ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index aa1ae9ac..7e2b8780 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.bfloat16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGemma is only available on GPU") diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py index 9608113b..86cfc7e2 100644 --- a/server/text_generation_server/models/flash_gemma2.py +++ b/server/text_generation_server/models/flash_gemma2.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.bfloat16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGemma2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 7e108d05..a530d1c3 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashPhi is only available on GPU") diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 23528f0b..cd6078f1 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -19,6 +19,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashQwen2 is only available on GPU") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 6d921721..011e0f63 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,6 +1,7 @@ import torch from loguru import logger import subprocess +import os def is_ipex_available(): @@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction): total_memory = torch.xpu.get_device_properties(device).total_memory device_id = device.index - query = f"xpu-smi dump -d {device_id} -m 18 -n 1" - output = subprocess.check_output(query.split()).decode("utf-8").split("\n") - used_memory = float(output[1].split(",")[-1]) * 1024 * 1024 - free_memory = int(total_memory * 0.95 - used_memory) + memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + free_memory = max( + 0, + int( + total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) + ), + ) return free_memory From 9eefb2f672052826ad4786aae463452d188bb75b Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 1 Jul 2024 09:08:05 -0400 Subject: [PATCH 03/13] fix: prefer serde structs over custom functions (#2127) * fix: prefer enum for chat object * fix: adjust typo * fix: enum CompletionType not ObjectType * fix: adjust typo * feat: leverage serde for conditional deser * fix: adjust HubTokenizerConfig after rebase * fix: update create_post_processor logic for token type * fix: adjust unwrap syntax in template * Fixing the post processor. --------- Co-authored-by: Nicolas Patry --- router/src/infer/mod.rs | 28 +++- router/src/lib.rs | 315 +++++++++++++++++++--------------------- router/src/main.rs | 29 ++-- router/src/server.rs | 38 ++--- 4 files changed, 207 insertions(+), 203 deletions(-) diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3..49282eb9 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, +}; +use crate::{ + FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, }; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -270,7 +272,11 @@ struct ChatTemplate { } impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + fn new( + template: String, + bos_token: Option, + eos_token: Option, + ) -> Self { let mut env = Box::new(Environment::new()); // enable things like .strip() or .capitalize() env.set_unknown_method_callback(pycompat::unknown_method_callback); @@ -287,8 +293,8 @@ impl ChatTemplate { Self { template, - bos_token, - eos_token, + bos_token: bos_token.map(|token| token.as_str().to_string()), + eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, } } @@ -301,9 +307,9 @@ impl ChatTemplate { if self.use_default_tool_template { if let Some(last_message) = messages.last_mut() { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { + last_message.content.push(MessageChunk::Text { text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); + }); } } } @@ -340,6 +346,14 @@ impl ToolGrammar { .unwrap_or_else(|| panic!("Tool with name {} not found", name)) .clone()] } + ToolType::Function { function } => { + let tool = req_tools + .iter() + .find(|tool| tool.function.name == function.name) + .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) + .clone(); + vec![tool] + } ToolType::OneOf => req_tools.to_owned(), }; diff --git a/router/src/lib.rs b/router/src/lib.rs index a5b97af3..9ecfa051 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,23 +53,40 @@ pub enum ChatTemplateVersions { Multiple(Vec), } +use std::path::Path; + #[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub bos_token: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub eos_token: Option, + pub bos_token: Option, + pub eos_token: Option, pub tokenizer_class: Option, pub add_bos_token: Option, pub add_eos_token: Option, } impl HubTokenizerConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(untagged)] +pub enum TokenizerConfigToken { + String(String), + Object { content: String }, +} + +impl TokenizerConfigToken { + pub fn as_str(&self) -> &str { + match self { + TokenizerConfigToken::String(s) => s, + TokenizerConfigToken::Object { content } => content, + } } } @@ -100,9 +117,10 @@ pub struct HubProcessorConfig { } impl HubProcessorConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) } } @@ -121,35 +139,6 @@ pub(crate) enum GrammarType { Regex(String), } -mod token_serde { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; - - match value { - Value::String(s) => Ok(Some(s)), - Value::Object(map) => { - if let Some(content) = map.get("content").and_then(|v| v.as_str()) { - Ok(Some(content.to_string())) - } else { - Err(de::Error::custom( - "content key not found in structured token", - )) - } - } - Value::Null => Ok(None), - _ => Err(de::Error::custom("invalid token format")), - } - } -} - #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { /// Model info @@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters { } } -mod prompt_serde { - use serde::{self, Deserialize, Deserializer}; - use serde_json::Value; +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[serde(try_from = "PromptDeserializer")] +pub struct Prompt(pub Vec); - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Deserialize)] +#[serde(untagged)] +enum PromptDeserializer { + Single(String), + Multiple(Vec), +} + +impl TryFrom for Prompt { + type Error = String; + + fn try_from(value: PromptDeserializer) -> Result { match value { - Value::String(s) => Ok(vec![s]), - Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom( - "Empty array detected. Do not use an empty array for the prompt.", - )), - Value::Array(arr) => arr - .iter() - .map(|v| match v { - Value::String(s) => Ok(s.to_owned()), - _ => Err(serde::de::Error::custom("Expected a string")), - }) - .collect(), - _ => Err(serde::de::Error::custom( - "Expected a string or an array of strings", - )), + PromptDeserializer::Single(s) => Ok(Prompt(vec![s])), + PromptDeserializer::Multiple(v) => { + if v.is_empty() { + Err( + "Empty array detected. Do not use an empty array for the prompt." + .to_string(), + ) + } else { + Ok(Prompt(v)) + } + } } } } @@ -396,8 +388,7 @@ pub struct CompletionRequest { /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] - #[serde(deserialize_with = "prompt_serde::deserialize")] - pub prompt: Vec, + pub prompt: Prompt, /// The maximum number of tokens that can be generated in the chat completion. #[serde(default)] @@ -445,7 +436,6 @@ pub struct CompletionRequest { #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] pub(crate) struct Completion { pub id: String, - pub object: String, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, - pub object: String, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -562,6 +551,15 @@ pub(crate) struct Usage { pub total_tokens: u32, } +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum CompletionType { + #[serde(rename = "chat.completion.chunk")] + ChatCompletionChunk(ChatCompletionChunk), + #[serde(rename = "chat.completion")] + ChatCompletion(ChatCompletion), +} + impl ChatCompletion { pub(crate) fn new( model: String, @@ -598,7 +596,6 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "chat.completion".into(), created, model, system_fingerprint, @@ -620,7 +617,6 @@ impl ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct CompletionCompleteChunk { pub id: String, - pub object: String, pub created: u64, pub choices: Vec, pub model: String, @@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk { #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, - pub object: String, #[schema(example = "1706270978")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -710,7 +705,6 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "chat.completion.chunk".to_string(), created, model, system_fingerprint, @@ -821,7 +815,6 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] pub tool_choice: Option, /// Response format constraints for the generation. @@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option { "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), ) } -#[derive(Clone, Deserialize, ToSchema, Serialize)] -enum ToolType { - FunctionName(String), + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] +#[serde(untagged)] +pub enum ToolType { OneOf, + FunctionName(String), + Function { function: FunctionName }, } -/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) -mod deserialize_tool_choice { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FunctionName { + pub name: String, +} - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(from = "ToolTypeDeserializer")] +pub struct ToolChoice(pub Option); +#[derive(Deserialize)] +#[serde(untagged)] +enum ToolTypeDeserializer { + None(Option), + Some(ToolType), +} + +impl From for ToolChoice { + fn from(value: ToolTypeDeserializer) -> Self { match value { - Value::String(s) => match s.as_str() { - "none" => Ok(None), - "auto" => Ok(Some(ToolType::OneOf)), - _ => Ok(Some(ToolType::FunctionName(s))), + ToolTypeDeserializer::None(opt) => match opt.as_deref() { + Some("none") => ToolChoice(None), + Some("auto") => ToolChoice(Some(ToolType::OneOf)), + Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), + None => ToolChoice(Some(ToolType::OneOf)), }, - Value::Object(map) => { - if let Some(content) = map - .get("function") - .and_then(|v| v.get("name")) - .and_then(|v| v.as_str()) - { - Ok(Some(ToolType::FunctionName(content.to_string()))) - } else { - Err(de::Error::custom("function key not found in tool choice")) - } - } - Value::Null => Ok(Some(ToolType::OneOf)), - _ => Err(de::Error::custom("invalid token format")), + ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), } } } @@ -950,26 +940,16 @@ pub(crate) struct ToolCall { } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Url { +pub struct Url { url: String, } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct ImageUrl { - image_url: Url, -} - -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Text { - text: String, -} - #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] -enum MessageChunk { - Text(Text), - ImageUrl(ImageUrl), +pub enum MessageChunk { + Text { text: String }, + ImageUrl { image_url: Url }, } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] @@ -977,35 +957,31 @@ pub struct Message { #[schema(example = "user")] role: String, #[schema(example = "My name is David and I")] - #[serde(deserialize_with = "message_content_serde::deserialize")] - content: Vec, + pub content: MessageContent, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] name: Option, } -mod message_content_serde { - use super::*; - use serde::{Deserialize, Deserializer}; +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + SingleText(String), + MultipleChunks(Vec), +} - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum Message { - Text(String), - Chunks(Vec), - } - let message: Message = Deserialize::deserialize(deserializer)?; - let chunks = match message { - Message::Text(text) => { - vec![MessageChunk::Text(Text { text })] +// Pushing a chunk to a single text message will convert it to a multiple chunks message +impl MessageContent { + pub fn push(&mut self, chunk: MessageChunk) { + match self { + MessageContent::SingleText(text) => { + *self = + MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); } - Message::Chunks(s) => s, - }; - Ok(chunks) + MessageContent::MultipleChunks(chunks) => { + chunks.push(chunk); + } + } } } @@ -1021,18 +997,17 @@ impl From for TextMessage { fn from(value: Message) -> Self { TextMessage { role: value.role, - content: value - .content - .into_iter() - .map(|c| match c { - MessageChunk::Text(Text { text }) => text, - MessageChunk::ImageUrl(image) => { - let url = image.image_url.url; - format!("![]({url})") - } - }) - .collect::>() - .join(""), + content: match value.content { + MessageContent::SingleText(text) => text, + MessageContent::MultipleChunks(chunks) => chunks + .into_iter() + .map(|chunk| match chunk { + MessageChunk::Text { text } => text, + MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), + }) + .collect::>() + .join(""), + }, } } } @@ -1240,9 +1215,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::String( + "<|begin▁of▁sentence|>".to_string() + )) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::String( + "<|end▁of▁sentence|>".to_string() + )) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); // in this case we expect the tokens to be encoded as structured tokens // we want the content of the structured token @@ -1275,9 +1257,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::Object { + content: "<|begin▁of▁sentence|>".to_string() + }) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::Object { + content: "<|end▁of▁sentence|>".to_string() + }) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); } #[test] @@ -1295,9 +1284,7 @@ mod tests { request.messages[0], Message { role: "user".to_string(), - content: vec![MessageChunk::Text(Text { - text: "What is Deep Learning?".to_string() - }),], + content: MessageContent::SingleText("What is Deep Learning?".to_string()), name: None } ); @@ -1321,10 +1308,10 @@ mod tests { request.messages[0], Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, + ]), name: None } ); @@ -1334,10 +1321,10 @@ mod tests { fn text_message_convert() { let message = Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), name: None }; let textmsg: TextMessage = message.into(); diff --git a/router/src/main.rs b/router/src/main.rs index 8a5cf459..8618f57e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -553,11 +553,11 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { let bos_token_id = tokenizer - .token_to_id(bos) + .token_to_id(bos.as_str()) .expect("Should have found the bos token id"); - special_tokens.push((bos.clone(), bos_token_id)); - single.push(format!("{}:0", bos)); - pair.push(format!("{}:0", bos)); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); } } @@ -567,17 +567,17 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { let eos_token_id = tokenizer - .token_to_id(eos) + .token_to_id(eos.as_str()) .expect("Should have found the eos token id"); - special_tokens.push((eos.clone(), eos_token_id)); - single.push(format!("{}:0", eos)); - pair.push(format!("{}:0", eos)); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); } } if add_bos_token { if let Some(bos) = bos_token { - pair.push(format!("{}:1", bos)); + pair.push(format!("{}:1", bos.as_str())); } } @@ -585,7 +585,7 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos)); + pair.push(format!("{}:1", eos.as_str())); } } @@ -611,14 +611,15 @@ enum RouterError { #[cfg(test)] mod tests { use super::*; + use text_generation_router::TokenizerConfigToken; #[test] fn test_create_post_processor() { let tokenizer_config = HubTokenizerConfig { add_bos_token: None, add_eos_token: None, - bos_token: Some("".to_string()), - eos_token: Some("".to_string()), + bos_token: Some(TokenizerConfigToken::String("".to_string())), + eos_token: Some(TokenizerConfigToken::String("".to_string())), chat_template: None, tokenizer_class: None, completion_template: None, @@ -629,9 +630,9 @@ mod tests { let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); let expected = TemplateProcessing::builder() - .try_single(":0 $A:0 :1") + .try_single(":0 $A:0") .unwrap() - .try_pair(":0 $A:0 $B:1") + .try_pair(":0 $A:0 :1 $B:1") .unwrap() .special_tokens(vec![("".to_string(), 1)]) .build() diff --git a/router/src/server.rs b/router/src/server.rs index 0cb08d4e..d24774f9 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,17 +12,18 @@ use crate::kserve::{ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, - HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, - Token, TokenizeResponse, Usage, Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, + Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, + Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, + CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, + VertexResponse, }; -use crate::{FunctionDefinition, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -635,7 +636,7 @@ async fn completions( )); } - if req.prompt.len() > info.max_client_batch_size { + if req.prompt.0.len() > info.max_client_batch_size { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -651,6 +652,7 @@ async fn completions( let generate_requests: Vec = req .prompt + .0 .iter() .map(|prompt| GenerateRequest { inputs: prompt.to_string(), @@ -705,7 +707,6 @@ async fn completions( event .json_data(CompletionCompleteChunk { id: "".to_string(), - object: "text_completion".to_string(), created: current_time, choices: vec![CompletionComplete { @@ -932,7 +933,6 @@ async fn completions( let response = Completion { id: "".to_string(), - object: "text_completion".to_string(), created: current_time, model: info.model_id.clone(), system_fingerprint: format!( @@ -1153,14 +1153,16 @@ async fn chat_completions( }; event - .json_data(ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - content, - tool_calls, - current_time, - logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + .json_data(CompletionType::ChatCompletionChunk( + ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + stream_token.details.map(|d| d.finish_reason.to_string()), + ), )) .unwrap_or_else(|e| { println!("Failed to serialize ChatCompletionChunk: {:?}", e); @@ -1228,7 +1230,7 @@ async fn chat_completions( (None, Some(generation.generated_text)) }; // build the complete response object with the full text - let response = ChatCompletion::new( + let response = CompletionType::ChatCompletion(ChatCompletion::new( model_id, system_fingerprint, output, @@ -1236,7 +1238,7 @@ async fn chat_completions( generation.details.unwrap(), logprobs, tool_calls, - ); + )); // wrap generation inside a Vec to match api-inference Ok((headers, Json(response)).into_response()) From 17cebc4506f5bbfead6f27f148a96554ab7228c4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 15:24:17 +0200 Subject: [PATCH 04/13] Fixing test. (#2152) From d0225b10156320f294647ac676c130d03626473d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 15:42:26 +0200 Subject: [PATCH 05/13] GH router. (#2153) --- .github/workflows/build.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3de270ea..6db7a505 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -80,6 +80,9 @@ jobs: uses: docker/setup-buildx-action@v3 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 From 4f55f15840ec8049aa8881d135038fb63b61953b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 23:25:54 +0200 Subject: [PATCH 06/13] Fixing baichuan override. (#2158) --- .../models/custom_modeling/flash_llama_modeling.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6b82aeca..0ea9f623 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -117,6 +117,11 @@ class FlashLlamaAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads + # Setting defaults for baichuan custom config which doesn't apply them. + config.rope_theta = getattr(config, "rope_theta", 10000) + config.num_key_value_heads = getattr( + config, "num_key_value_heads", config.num_attention_heads + ) self.rotary_emb = PositionRotaryEmbedding.static( config=config, dim=self.head_size, From 4327210e6b3e12db45b332d750fc8cb7da9b86c6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 23:28:00 +0200 Subject: [PATCH 07/13] [Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940) * Using flash decoding Conditional flashdecoding. Fix max_q. Working kvcache Working version with flash decoding. Make it work for mistral. Fix after rebase.. Less intrusive. REvert changes in modeling. Speedup flashdecoding. HHachweew Hack to make other models work. Fixing non flash decoding llama path. Router logic knows about page size. Missing 2 models. Missing cohere. Fixing cohere flash decoding. Revamped all this architecture. Fix cohere. Fixing falcon. Enabling custom block size schedule. Update router/src/infer.rs Not sending preallocated output. * Making it work on non flash decoding. * Fix Cohere. * Fix non decoding paths. * Rebased. * No need for cache_manager anymore. * Update? * "ipex" -> "cpu" * These do not belong. * Factoring cu_seqlen_qk for better abstracting over every model. * Fixing non flash tests/imports. * Changing return everywhere. * Update mistral past. * Fixing Mi{s,x}tral (non functional in Flash Decoding mode though). * Fixup mistral clamping (had issues with cuda graphs). * No need to recreate anything actually. --- router/src/infer/v2/scheduler.rs | 9 +- router/src/infer/v3/scheduler.rs | 8 +- .../layers/attention/__init__.py | 2 + .../layers/attention/common.py | 44 ++++++ .../layers/attention/cuda.py | 139 ++++++++++++------ .../layers/attention/ipex.py | 5 +- .../layers/attention/rocm.py | 15 +- .../text_generation_server/models/__init__.py | 3 +- .../custom_modeling/flash_cohere_modeling.py | 6 +- .../custom_modeling/flash_dbrx_modeling.py | 2 +- .../custom_modeling/flash_gemma2_modeling.py | 2 +- .../custom_modeling/flash_gemma_modeling.py | 2 +- .../custom_modeling/flash_gpt2_modeling.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 3 +- .../custom_modeling/flash_mistral_modeling.py | 5 +- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../custom_modeling/flash_neox_modeling.py | 2 +- .../custom_modeling/flash_phi_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 2 +- .../custom_modeling/flash_rw_modeling.py | 4 +- .../flash_santacoder_modeling.py | 2 +- .../flash_starcoder2_modeling.py | 2 +- .../models/flash_causal_lm.py | 25 +++- .../text_generation_server/models/globals.py | 8 +- 24 files changed, 223 insertions(+), 75 deletions(-) create mode 100644 server/text_generation_server/layers/attention/common.py diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index ba6f520d..e4c3de26 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -39,7 +39,14 @@ impl SchedulerV2 { speculate: u32, generation_health: Arc, ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); + // Infer shared state + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; + let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ad03dd83..543ce89f 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -39,9 +39,15 @@ impl SchedulerV3 { speculate: u32, generation_health: Arc, ) -> Self { + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, - 16, + block_size, window_size, speculate, max_batch_total_tokens, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index e74180e7..c8bccefe 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,6 +1,8 @@ from text_generation_server.utils.import_utils import SYSTEM import os +from .common import Seqlen + if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py new file mode 100644 index 00000000..bd0717ce --- /dev/null +++ b/server/text_generation_server/layers/attention/common.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from text_generation_server.models.globals import FLASH_DECODING +import torch +from typing import Optional + + +if FLASH_DECODING: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + + def __init__(self, input_lengths): + self.input_lengths = input_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + +else: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + + def clamp(self, max): + return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 583337bd..94b69899 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,5 +1,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.layers.attention import Seqlen major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 @@ -21,7 +23,14 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -32,7 +41,7 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -53,7 +62,8 @@ def paged_attention( # # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] + # block_size = value_cache.shape[3] + block_size = BLOCK_SIZE num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE @@ -62,58 +72,95 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - from vllm._C import ops + if FLASH_DECODING: + max_q = 1 + max_k = max_s + import flash_attn_2_cuda - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, + # TODO fixme when flash contains the fix. + # Number of splits is not correctly handled + # by the current path + # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 + # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. + out2 = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, None, - "auto", - 1.0, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + block_tables, + None, + max_q, + max_k, + 0.0, # dropout + softmax_scale, + False, # zero_tensors + True, # causal + -1, # Window_left + -1, # Window right + False, # return softmax + None, # generator ) + return out2[0] else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) + input_lengths = seqlen.input_lengths + from vllm._C import ops - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, + use_v1 = max_s <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 ) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + return out try: diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 7f086b68..db79c589 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -55,7 +55,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( @@ -66,7 +67,7 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - input_lengths, + cu_seqlen_q, BLOCK_SIZE, max_s, None, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 91ed5818..36db12d0 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,6 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -26,7 +27,14 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -37,7 +45,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -61,6 +70,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + input_lengths = cu_seqlen_k # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -119,6 +129,7 @@ def paged_attention( "auto", 1.0, ) + return out if ENGINE != "triton": diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f2f0f457..5ea43290 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -12,7 +12,6 @@ from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM @@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." FLASH_ATTENTION = True try: + from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.flash_neox import FlashNeoXSharded @@ -92,6 +92,7 @@ except ImportError as e: FLASH_ATTENTION = False if FLASH_ATTENTION: + __all__.append(FlashCausalLM) __all__.append(FlashGPT2) __all__.append(FlashNeoXSharded) __all__.append(FlashRWSharded) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 2850a6f3..e088f9aa 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, - slots, input_lengths, + slots, max_s, ): qkv = self.query_key_value(hidden_states) @@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module): ) residual = None + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9d56e4ef..aea7f399 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a71de61f..cfa6b2fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 82891823..842df0d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 7e7510c7..9f800146 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0ea9f623..77a7e2d5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d1ba5564..69ed5f64 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( + Seqlen, paged_attention, attention, reshape_and_cache, @@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2e839d15..2d6a7f97 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b87fd4ca..33aebc2b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 3f445f97..f237ea37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 69f38c3a..2e281386 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 04d4ba51..e7614232 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index badfc367..30989a37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index f6a2e15d..df864bc1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a0a78b33..49a088a1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -30,10 +30,13 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, + FLASH_DECODING, + BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, MODEL_ID, ) +from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments @@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) -BLOCK_SIZE: int = 16 # Will be set in init SLIDING_WINDOW: Optional[int] = None @@ -856,7 +858,23 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if SYSTEM == "ipex" and device == torch.device("cpu"): + if FLASH_DECODING: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + elif SYSTEM == "ipex" and device == torch.device("cpu"): self.kv_cache = [ ( torch.empty( @@ -908,6 +926,7 @@ class FlashCausalLM(Model): "slots": slots, "input_lengths": input_lengths, } + input_lengths = Seqlen(input_lengths=input_lengths) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1067,6 +1086,7 @@ class FlashCausalLM(Model): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + input_lengths = Seqlen(input_lengths=input_lengths) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1153,6 +1173,7 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index bde86e6e..06035ccd 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,12 @@ from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli +FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} +BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 +if FLASH_DECODING: + logger.info("Using FLASH_DECODING") + + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: @@ -15,8 +21,6 @@ if cuda_graphs is not None: ) else: cuda_graphs = None - - # sorting the cuda graphs in descending order helps reduce the # memory impact and results in less memory usage if cuda_graphs is not None: From 022f6515a48c26cd505b4bbfa8da3bd00e77f078 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 11:43:07 +0200 Subject: [PATCH 08/13] Fixing graph capture for flash decoding. (#2163) --- server/text_generation_server/models/flash_causal_lm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 49a088a1..4f276ed4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -926,7 +926,7 @@ class FlashCausalLM(Model): "slots": slots, "input_lengths": input_lengths, } - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -939,7 +939,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -947,6 +947,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, From 5d97e0c4a3688ef462472167242c48570b8125c5 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 2 Jul 2024 17:56:07 +0800 Subject: [PATCH 09/13] fix FlashDecoding change's regression in intel platform (#2161) install triton because GPTQParams needs it. Signed-off-by: Wang, Yi A --- Dockerfile_intel | 2 ++ server/text_generation_server/layers/attention/ipex.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index a41fbc1e..3c060f19 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl +RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed # Install server @@ -132,6 +133,7 @@ RUN conda install -c conda-forge gperftools mkl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install triton WORKDIR /usr/src diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index db79c589..45a0a03e 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,6 +1,7 @@ import intel_extension_for_pytorch as ipex import torch from text_generation_server.models.flash_causal_lm import BLOCK_SIZE +from text_generation_server.layers.attention import Seqlen SUPPORTS_WINDOWING = False @@ -55,11 +56,10 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - cu_seqlen_q: torch.Tensor, - cu_seqlen_k: torch.Tensor, + seqlen: Seqlen, max_s: int, ): - return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, key_cache, @@ -67,8 +67,9 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - cu_seqlen_q, + seqlen.input_lengths, BLOCK_SIZE, max_s, None, ) + return out From b966bc0d35ef81d8cf2ba1e0f519d7e00c9e356e Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 2 Jul 2024 05:56:25 -0400 Subject: [PATCH 10/13] fix: use the base layers weight in mistral rocm (#2155) --- .../models/custom_modeling/flash_mistral_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 69ed5f64..396969cd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -315,7 +315,9 @@ class MistralMLP(nn.Module): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + _custom_C.LLMM_Silu( + self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 + ) return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) From dea9c0dc741875fde9225e6c2a51d7bb8fb052e4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 12:01:08 +0200 Subject: [PATCH 11/13] Fixing rocm. (#2164) --- server/text_generation_server/layers/attention/rocm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 36db12d0..99c490d5 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -2,6 +2,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.layers.attention import Seqlen from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -45,8 +46,7 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - cu_seqlen_q: torch.Tensor, - cu_seqlen_k: torch.Tensor, + input_lengths: Seqlen, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -70,7 +70,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = cu_seqlen_k + input_lengths = input_lengths.input_lengths # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use From 963b6c6f0fce7bb791796a37e90f131740d38dff Mon Sep 17 00:00:00 2001 From: Guillaume LEGENDRE Date: Tue, 2 Jul 2024 12:45:38 +0200 Subject: [PATCH 12/13] Ci test (#2124) * first test with registry mirror * change push registry * remove comments * Move cache to push registry * fix registry url * Update .github/workflows/ci_build.yaml --------- Co-authored-by: Nicolas Patry --- .github/workflows/build.yaml | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 6db7a505..b0049701 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,12 +70,6 @@ jobs: echo "LABEL=${label_extension}" >> $GITHUB_ENV echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV - - name: Tailscale - uses: huggingface/tailscale-action@main - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} - slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }} - slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v3 with: @@ -90,12 +84,6 @@ jobs: registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Login to internal Container Registry - uses: docker/login-action@v3 - with: - username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} - password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} - registry: registry.internal.huggingface.tech - name: Login to Azure Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 @@ -110,7 +98,7 @@ jobs: uses: docker/metadata-action@v5 with: images: | - registry.internal.huggingface.tech/api-inference/community/text-generation-inference + registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference tags: | type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} # If main, release or tag @@ -122,7 +110,7 @@ jobs: flavor: | latest=auto images: | - registry.internal.huggingface.tech/api-inference/community/text-generation-inference + registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference ghcr.io/huggingface/text-generation-inference db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference tags: | @@ -143,12 +131,12 @@ jobs: DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} - cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min - cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min + cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min + cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min - name: Final id: final run: | - echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" + echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT" echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT" echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT" @@ -173,11 +161,6 @@ jobs: - name: Install run: | make install-integration-tests - - name: Tailscale - uses: huggingface/tailscale-action@main - if: needs.build-and-push.outputs.runs_on != 'amd-gpu-tgi' - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} - name: Run tests run: | export DOCKER_VOLUME=/mnt/cache From 0759ec495e15a865d2a59befc2b796b5acc09b50 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 14:26:47 +0200 Subject: [PATCH 13/13] Hotfixing qwen2 and starcoder2 (which also get clamping). (#2167) --- .../models/custom_modeling/flash_qwen2_modeling.py | 2 +- .../models/custom_modeling/flash_starcoder2_modeling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 2e281386..1cc6a613 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index df864bc1..a0273c37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids,