diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 121917f0..11420909 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -16,7 +16,6 @@ jobs:
   build-and-push:
     outputs:
       docker_image: ${{ steps.final.outputs.docker_image }}
-      base_docker_image: ${{ steps.final.outputs.base_docker_image }}
       docker_devices: ${{ steps.final.outputs.docker_devices }}
       docker_volume: ${{ steps.final.outputs.docker_volume}}
       runs_on: ${{ steps.final.outputs.runs_on }}
@@ -73,17 +72,13 @@ jobs:
          echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
          echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV

-      - name: Tailscale
-        uses: huggingface/tailscale-action@main
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-          slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
-          slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
-
       - name: Initialize Docker Buildx
         uses: docker/setup-buildx-action@v3
         with:
           install: true
+          config-inline: |
+            [registry."docker.io"]
+              mirrors = ["registry.github-runners.huggingface.tech"]

       - name: Login to GitHub Container Registry
         if: github.event_name != 'pull_request'
@@ -93,13 +88,6 @@ jobs:
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}

-      - name: Login to internal Container Registry
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
-          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
-          registry: registry.internal.huggingface.tech
-
       - name: Login to Azure Container Registry
         if: github.event_name != 'pull_request'
         uses: docker/login-action@v3
@@ -115,10 +103,9 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
           tags: |
             type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
-        # If main, release or tag

       - name: Extract metadata (tags, labels) for Docker
         if: ${{ github.event_name != 'pull_request' }}
@@ -128,7 +115,7 @@ jobs:
           flavor: |
             latest=auto
           images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
             ghcr.io/huggingface/text-generation-inference
             db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
           tags: |
@@ -136,7 +123,6 @@ jobs:
             type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }}
             type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
             type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
-
       - name: Build and push Docker image
         id: build-and-push
         uses: docker/build-push-action@v4
@@ -150,30 +136,16 @@ jobs:
             DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
           tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
-          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
-
+          cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
+          cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
       - name: Final
         id: final
         run: |
-          echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+          echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
           echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
           echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
           echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
-          if [[ ${{ inputs.hardware }} == "rocm" ]]
-          then
-            echo "base_docker_image=rocm/dev-ubuntu-22.04:6.1.1_hip_update" >> "$GITHUB_OUTPUT"
-          elif [[ ${{ inputs.hardware }} == "cuda" ]]
-          then
-            echo "base_docker_image=nvidia/cuda:12.1.0-base-ubuntu22.04" >> "$GITHUB_OUTPUT"
-          elif [[ ${{ inputs.hardware }} == "xpu" ]]
-          then
-            echo "base_docker_image=intel/intel-extension-for-pytorch:2.1.30-xpu" >> "$GITHUB_OUTPUT"
-          else
-            exit 1
-          fi
-
           if [[ ${{ inputs.hardware }} == "rocm" ]]
           then
             echo "docker_volume=/data/cache/.cache/huggingface/hub" >> "$GITHUB_OUTPUT"
@@ -191,7 +163,7 @@ jobs:
     # Ideally, we would use the image from registry.internal.huggingface.tech but we can not login to the private registry outside of tailscale,
     # and even adding a previous job with tailscale login still results in `Docker login for 'registry.internal.huggingface.tech' failed with exit code 1`.
     container:
-      image: ${{ needs.build-and-push.outputs.base_docker_image }}
+      image: ${{ needs.build-and-push.outputs.docker_image }}
      options: --shm-size "16gb" --ipc host -v ${{ needs.build-and-push.outputs.docker_volume }}:/data
     steps:
       - name: Checkout repository
@@ -207,8 +179,6 @@ jobs:
           echo "ls:"
           ls
-          pip3 install -U huggingface_hub
-          python3 integration-tests/clean_cache_and_download.py --token ${{ secrets.HF_TOKEN }} --cache-dir /data
       # Avoid permissions issues in the next step not run within docker (File was unable to be removed Error: EACCES).
@@ -242,12 +212,6 @@ jobs:
         run: |
           make install-integration-tests

-      - name: Tailscale
-        uses: huggingface/tailscale-action@main
-        if: needs.build-and-push.outputs.runs_on != 'amd-gpu-tgi'
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-
       - name: Run tests
         run: |
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
diff --git a/Dockerfile_intel b/Dockerfile_intel
index a41fbc1e..3c060f19 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
 WORKDIR /usr/src
 RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

 # Install server
@@ -132,6 +133,7 @@ RUN conda install -c conda-forge gperftools mkl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install triton

 WORKDIR /usr/src
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 07c334a3..49282eb9 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::{
     ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
-    HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
+    HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
+};
+use crate::{
+    FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
 };
-use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
@@ -270,7 +272,11 @@ struct ChatTemplate {
 }

 impl ChatTemplate {
-    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
+    fn new(
+        template: String,
+        bos_token: Option<TokenizerConfigToken>,
+        eos_token: Option<TokenizerConfigToken>,
+    ) -> Self {
         let mut env = Box::new(Environment::new());
         // enable things like .strip() or .capitalize()
         env.set_unknown_method_callback(pycompat::unknown_method_callback);
@@ -287,8 +293,8 @@ impl ChatTemplate {

         Self {
             template,
-            bos_token,
-            eos_token,
+            bos_token: bos_token.map(|token| token.as_str().to_string()),
+            eos_token: eos_token.map(|token| token.as_str().to_string()),
             use_default_tool_template,
         }
     }
@@ -301,9 +307,9 @@ impl ChatTemplate {
         if self.use_default_tool_template {
             if let Some(last_message) = messages.last_mut() {
                 if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text(Text {
+                    last_message.content.push(MessageChunk::Text {
                         text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    }));
+                    });
                 }
             }
         }
@@ -340,6 +346,14 @@ impl ToolGrammar {
                     .unwrap_or_else(|| panic!("Tool with name {} not found", name))
                     .clone()]
.clone()] } + ToolType::Function { function } => { + let tool = req_tools + .iter() + .find(|tool| tool.function.name == function.name) + .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) + .clone(); + vec![tool] + } ToolType::OneOf => req_tools.to_owned(), }; diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index ba6f520d..e4c3de26 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -39,7 +39,14 @@ impl SchedulerV2 { speculate: u32, generation_health: Arc, ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); + // Infer shared state + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; + let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ad03dd83..543ce89f 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -39,9 +39,15 @@ impl SchedulerV3 { speculate: u32, generation_health: Arc, ) -> Self { + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, - 16, + block_size, window_size, speculate, max_batch_total_tokens, diff --git a/router/src/lib.rs b/router/src/lib.rs index a5b97af3..9ecfa051 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,23 +53,40 @@ pub enum ChatTemplateVersions { Multiple(Vec), } +use std::path::Path; + #[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub bos_token: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub eos_token: Option, + pub bos_token: Option, + pub eos_token: Option, pub tokenizer_class: Option, pub add_bos_token: Option, pub add_eos_token: Option, } impl HubTokenizerConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(untagged)] +pub enum TokenizerConfigToken { + String(String), + Object { content: String }, +} + +impl TokenizerConfigToken { + pub fn as_str(&self) -> &str { + match self { + TokenizerConfigToken::String(s) => s, + TokenizerConfigToken::Object { content } => content, + } } } @@ -100,9 +117,10 @@ pub struct HubProcessorConfig { } impl HubProcessorConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) } } @@ -121,35 +139,6 @@ pub(crate) enum GrammarType { Regex(String), } -mod token_serde { - use super::*; - use serde::de; - use 
diff --git a/router/src/lib.rs b/router/src/lib.rs
index a5b97af3..9ecfa051 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -53,23 +53,40 @@ pub enum ChatTemplateVersions {
     Multiple(Vec<String>),
 }

+use std::path::Path;
+
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<ChatTemplateVersions>,
     pub completion_template: Option<String>,
-    #[serde(deserialize_with = "token_serde::deserialize")]
-    pub bos_token: Option<String>,
-    #[serde(deserialize_with = "token_serde::deserialize")]
-    pub eos_token: Option<String>,
+    pub bos_token: Option<TokenizerConfigToken>,
+    pub eos_token: Option<TokenizerConfigToken>,
     pub tokenizer_class: Option<String>,
     pub add_bos_token: Option<bool>,
     pub add_eos_token: Option<bool>,
 }

 impl HubTokenizerConfig {
-    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
-        let content = std::fs::read_to_string(filename).ok()?;
-        serde_json::from_str(&content).ok()
+    pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
+        std::fs::read_to_string(filename)
+            .ok()
+            .and_then(|content| serde_json::from_str(&content).ok())
+    }
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+#[serde(untagged)]
+pub enum TokenizerConfigToken {
+    String(String),
+    Object { content: String },
+}
+
+impl TokenizerConfigToken {
+    pub fn as_str(&self) -> &str {
+        match self {
+            TokenizerConfigToken::String(s) => s,
+            TokenizerConfigToken::Object { content } => content,
+        }
     }
 }

@@ -100,9 +117,10 @@ pub struct HubProcessorConfig {
 }

 impl HubProcessorConfig {
-    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
-        let content = std::fs::read_to_string(filename).ok()?;
-        serde_json::from_str(&content).ok()
+    pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
+        std::fs::read_to_string(filename)
+            .ok()
+            .and_then(|content| serde_json::from_str(&content).ok())
     }
 }

@@ -121,35 +139,6 @@ pub(crate) enum GrammarType {
     Regex(String),
 }

-mod token_serde {
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
-    use serde_json::Value;
-
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(Some(s)),
-            Value::Object(map) => {
-                if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
-                    Ok(Some(content.to_string()))
-                } else {
-                    Err(de::Error::custom(
-                        "content key not found in structured token",
-                    ))
-                }
-            }
-            Value::Null => Ok(None),
-            _ => Err(de::Error::custom("invalid token format")),
-        }
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters {
     }
 }

-mod prompt_serde {
-    use serde::{self, Deserialize, Deserializer};
-    use serde_json::Value;
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+#[serde(try_from = "PromptDeserializer")]
+pub struct Prompt(pub Vec<String>);

-    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
+#[derive(Deserialize)]
+#[serde(untagged)]
+enum PromptDeserializer {
+    Single(String),
+    Multiple(Vec<String>),
+}
+
+impl TryFrom<PromptDeserializer> for Prompt {
+    type Error = String;
+
+    fn try_from(value: PromptDeserializer) -> Result<Self, Self::Error> {
         match value {
-            Value::String(s) => Ok(vec![s]),
-            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
-                "Empty array detected. Do not use an empty array for the prompt.",
-            )),
-            Value::Array(arr) => arr
-                .iter()
-                .map(|v| match v {
-                    Value::String(s) => Ok(s.to_owned()),
-                    _ => Err(serde::de::Error::custom("Expected a string")),
-                })
-                .collect(),
-            _ => Err(serde::de::Error::custom(
-                "Expected a string or an array of strings",
-            )),
+            PromptDeserializer::Single(s) => Ok(Prompt(vec![s])),
+            PromptDeserializer::Multiple(v) => {
+                if v.is_empty() {
+                    Err(
+                        "Empty array detected. Do not use an empty array for the prompt."
+                            .to_string(),
+                    )
+                } else {
+                    Ok(Prompt(v))
+                }
+            }
         }
     }
 }

@@ -396,8 +388,7 @@ pub struct CompletionRequest {

     /// The prompt to generate completions for.
     #[schema(example = "What is Deep Learning?")]
-    #[serde(deserialize_with = "prompt_serde::deserialize")]
-    pub prompt: Vec<String>,
+    pub prompt: Prompt,

     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
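`Prompt` now accepts either a bare string or a non-empty array of strings via `PromptDeserializer`. The equivalent validation, sketched in Python (hypothetical helper, for illustration):

```python
from typing import List, Union

def parse_prompt(value: Union[str, List[str]]) -> List[str]:
    # A bare string becomes a single-element batch.
    if isinstance(value, str):
        return [value]
    # Arrays must be non-empty, matching the TryFrom impl above.
    if not value:
        raise ValueError("Empty array detected. Do not use an empty array for the prompt.")
    return list(value)

assert parse_prompt("What is Deep Learning?") == ["What is Deep Learning?"]
assert parse_prompt(["a", "b"]) == ["a", "b"]
```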
@@ -445,7 +436,6 @@ pub struct CompletionRequest {
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
 pub(crate) struct Completion {
     pub id: String,
-    pub object: String,
     #[schema(example = "1706270835")]
     pub created: u64,
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete {
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
-    pub object: String,
     #[schema(example = "1706270835")]
     pub created: u64,
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@@ -562,6 +551,15 @@ pub(crate) struct Usage {
     pub total_tokens: u32,
 }

+#[derive(Clone, Serialize, ToSchema)]
+#[serde(tag = "object")]
+enum CompletionType {
+    #[serde(rename = "chat.completion.chunk")]
+    ChatCompletionChunk(ChatCompletionChunk),
+    #[serde(rename = "chat.completion")]
+    ChatCompletion(ChatCompletion),
+}
+
 impl ChatCompletion {
     pub(crate) fn new(
         model: String,
@@ -598,7 +596,6 @@ impl ChatCompletion {
         };
         Self {
             id: String::new(),
-            object: "chat.completion".into(),
             created,
             model,
             system_fingerprint,
@@ -620,7 +617,6 @@ impl ChatCompletion {
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct CompletionCompleteChunk {
     pub id: String,
-    pub object: String,
     pub created: u64,
     pub choices: Vec<CompletionComplete>,
     pub model: String,
@@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk {
 #[derive(Clone, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
-    pub object: String,
     #[schema(example = "1706270978")]
     pub created: u64,
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@@ -710,7 +705,6 @@ impl ChatCompletionChunk {
         };
         Self {
             id: String::new(),
-            object: "chat.completion.chunk".to_string(),
             created,
             model,
             system_fingerprint,
@@ -821,7 +815,6 @@ pub(crate) struct ChatRequest {
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
-    #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,

     /// Response format constraints for the generation.
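Dropping the hard-coded `object` strings and reintroducing them through `#[serde(tag = "object")]` on `CompletionType` means the discriminant is written exactly once, at serialization time. A consumer-side sketch of dispatching on that tag (illustrative Python):

```python
import json

def classify(event: str) -> str:
    """Dispatch on the serde-written "object" tag, as an SSE consumer would."""
    obj = json.loads(event)["object"]
    if obj == "chat.completion.chunk":
        return "streaming delta"
    if obj == "chat.completion":
        return "final response"
    raise ValueError(f"unknown object type: {obj}")

assert classify('{"object": "chat.completion.chunk", "id": ""}') == "streaming delta"
```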
@@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option<String> {
         "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
     )
 }
-#[derive(Clone, Deserialize, ToSchema, Serialize)]
-enum ToolType {
-    FunctionName(String),
+
+#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
+#[serde(untagged)]
+pub enum ToolType {
     OneOf,
+    FunctionName(String),
+    Function { function: FunctionName },
 }

-/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
-mod deserialize_tool_choice {
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
-    use serde_json::Value;
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct FunctionName {
+    pub name: String,
+}

-    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[serde(from = "ToolTypeDeserializer")]
+pub struct ToolChoice(pub Option<ToolType>);

+#[derive(Deserialize)]
+#[serde(untagged)]
+enum ToolTypeDeserializer {
+    None(Option<String>),
+    Some(ToolType),
+}
+
+impl From<ToolTypeDeserializer> for ToolChoice {
+    fn from(value: ToolTypeDeserializer) -> Self {
         match value {
-            Value::String(s) => match s.as_str() {
-                "none" => Ok(None),
-                "auto" => Ok(Some(ToolType::OneOf)),
-                _ => Ok(Some(ToolType::FunctionName(s))),
+            ToolTypeDeserializer::None(opt) => match opt.as_deref() {
+                Some("none") => ToolChoice(None),
+                Some("auto") => ToolChoice(Some(ToolType::OneOf)),
+                Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
+                None => ToolChoice(Some(ToolType::OneOf)),
             },
-            Value::Object(map) => {
-                if let Some(content) = map
-                    .get("function")
-                    .and_then(|v| v.get("name"))
-                    .and_then(|v| v.as_str())
-                {
-                    Ok(Some(ToolType::FunctionName(content.to_string())))
-                } else {
-                    Err(de::Error::custom("function key not found in tool choice"))
-                }
-            }
-            Value::Null => Ok(Some(ToolType::OneOf)),
-            _ => Err(de::Error::custom("invalid token format")),
+            ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
         }
     }
 }
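`ToolTypeDeserializer` accepts every OpenAI-style `tool_choice` value: `null` and `"auto"` select any tool, `"none"` disables tools, and a bare name or a `{"function": {"name": ...}}` object pins one tool. The same mapping, sketched in Python (hypothetical helper and return shapes, for illustration):

```python
from typing import Any, Optional

def parse_tool_choice(value: Any) -> Optional[dict]:
    # Mirrors From<ToolTypeDeserializer> for ToolChoice above.
    if value is None or value == "auto":
        return {"one_of": True}              # ToolType::OneOf
    if value == "none":
        return None                          # tools disabled
    if isinstance(value, str):
        return {"function_name": value}      # ToolType::FunctionName
    return {"function_name": value["function"]["name"]}  # ToolType::Function

assert parse_tool_choice("none") is None
assert parse_tool_choice({"function": {"name": "get_weather"}}) == {"function_name": "get_weather"}
```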
@@ -950,26 +940,16 @@ pub(crate) struct ToolCall {
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
-struct Url {
+pub struct Url {
     url: String,
 }

-#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
-struct ImageUrl {
-    image_url: Url,
-}
-
-#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
-struct Text {
-    text: String,
-}
-
 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
 #[serde(tag = "type")]
 #[serde(rename_all = "snake_case")]
-enum MessageChunk {
-    Text(Text),
-    ImageUrl(ImageUrl),
+pub enum MessageChunk {
+    Text { text: String },
+    ImageUrl { image_url: Url },
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -977,35 +957,31 @@ pub struct Message {
     #[schema(example = "user")]
     role: String,
     #[schema(example = "My name is David and I")]
-    #[serde(deserialize_with = "message_content_serde::deserialize")]
-    content: Vec<MessageChunk>,
+    pub content: MessageContent,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     name: Option<String>,
 }

-mod message_content_serde {
-    use super::*;
-    use serde::{Deserialize, Deserializer};
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+    SingleText(String),
+    MultipleChunks(Vec<MessageChunk>),
+}

-    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        #[serde(untagged)]
-        enum Message {
-            Text(String),
-            Chunks(Vec<MessageChunk>),
-        }
-        let message: Message = Deserialize::deserialize(deserializer)?;
-        let chunks = match message {
-            Message::Text(text) => {
-                vec![MessageChunk::Text(Text { text })]
+// Pushing a chunk to a single text message will convert it to a multiple chunks message
+impl MessageContent {
+    pub fn push(&mut self, chunk: MessageChunk) {
+        match self {
+            MessageContent::SingleText(text) => {
+                *self = MessageContent::MultipleChunks(vec![
+                    MessageChunk::Text { text: text.clone() },
+                    chunk,
+                ]);
             }
-            Message::Chunks(s) => s,
-        };
-        Ok(chunks)
+            MessageContent::MultipleChunks(chunks) => {
+                chunks.push(chunk);
+            }
+        }
     }
 }

@@ -1021,18 +997,17 @@ impl From<Message> for TextMessage {
     fn from(value: Message) -> Self {
         TextMessage {
             role: value.role,
-            content: value
-                .content
-                .into_iter()
-                .map(|c| match c {
-                    MessageChunk::Text(Text { text }) => text,
-                    MessageChunk::ImageUrl(image) => {
-                        let url = image.image_url.url;
-                        format!("![]({url})")
-                    }
-                })
-                .collect::<Vec<String>>()
-                .join(""),
+            content: match value.content {
+                MessageContent::SingleText(text) => text,
+                MessageContent::MultipleChunks(chunks) => chunks
+                    .into_iter()
+                    .map(|chunk| match chunk {
+                        MessageChunk::Text { text } => text,
+                        MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
+                    })
+                    .collect::<Vec<String>>()
+                    .join(""),
+            },
         }
     }
 }
@@ -1240,9 +1215,16 @@ mod tests {
         );
         assert_eq!(
             config.bos_token,
-            Some("<|begin▁of▁sentence|>".to_string())
+            Some(TokenizerConfigToken::String(
+                "<|begin▁of▁sentence|>".to_string()
+            ))
+        );
+        assert_eq!(
+            config.eos_token,
+            Some(TokenizerConfigToken::String(
+                "<|end▁of▁sentence|>".to_string()
+            ))
         );
-        assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));

         // in this case we expect the tokens to be encoded as structured tokens
         // we want the content of the structured token
@@ -1275,9 +1257,16 @@ mod tests {
         );
         assert_eq!(
             config.bos_token,
-            Some("<|begin▁of▁sentence|>".to_string())
+            Some(TokenizerConfigToken::Object {
+                content: "<|begin▁of▁sentence|>".to_string()
+            })
+        );
+        assert_eq!(
+            config.eos_token,
+            Some(TokenizerConfigToken::Object {
+                content: "<|end▁of▁sentence|>".to_string()
+            })
         );
-        assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
     }

     #[test]
@@ -1295,9 +1284,7 @@
             request.messages[0],
             Message {
                 role: "user".to_string(),
-                content: vec![MessageChunk::Text(Text {
-                    text: "What is Deep Learning?".to_string()
-                }),],
+                content: MessageContent::SingleText("What is Deep Learning?".to_string()),
                 name: None
             }
         );
@@ -1321,10 +1308,10 @@
             request.messages[0],
             Message{
                 role: "user".to_string(),
-                content: vec![
-                    MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
-                    MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
-                ],
+                content: MessageContent::MultipleChunks(vec![
+                    MessageChunk::Text { text: "Whats in this image?".to_string() },
+                    MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
+                ]),
                 name: None
             }
         );
     }
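`MessageContent::push` promotes a plain-text message to the chunked representation before appending; note that the new vec must contain both the original text chunk and the incoming one, otherwise the pushed chunk would be silently dropped. The behaviour, sketched in Python:

```python
def push_chunk(content, chunk):
    # A plain string is promoted to a chunk list before appending,
    # mirroring MessageContent::push above.
    if isinstance(content, str):
        return [{"type": "text", "text": content}, chunk]
    return content + [chunk]

msg = "What is Deep Learning?"
msg = push_chunk(msg, {"type": "text", "text": "\n---\nschema"})
assert len(msg) == 2
```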
"user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), name: None }; let textmsg: TextMessage = message.into(); diff --git a/router/src/main.rs b/router/src/main.rs index 8a5cf459..8618f57e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -553,11 +553,11 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { let bos_token_id = tokenizer - .token_to_id(bos) + .token_to_id(bos.as_str()) .expect("Should have found the bos token id"); - special_tokens.push((bos.clone(), bos_token_id)); - single.push(format!("{}:0", bos)); - pair.push(format!("{}:0", bos)); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); } } @@ -567,17 +567,17 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { let eos_token_id = tokenizer - .token_to_id(eos) + .token_to_id(eos.as_str()) .expect("Should have found the eos token id"); - special_tokens.push((eos.clone(), eos_token_id)); - single.push(format!("{}:0", eos)); - pair.push(format!("{}:0", eos)); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); } } if add_bos_token { if let Some(bos) = bos_token { - pair.push(format!("{}:1", bos)); + pair.push(format!("{}:1", bos.as_str())); } } @@ -585,7 +585,7 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos)); + pair.push(format!("{}:1", eos.as_str())); } } @@ -611,14 +611,15 @@ enum RouterError { #[cfg(test)] mod tests { use super::*; + use text_generation_router::TokenizerConfigToken; #[test] fn test_create_post_processor() { let tokenizer_config = HubTokenizerConfig { add_bos_token: None, add_eos_token: None, - bos_token: Some("".to_string()), - eos_token: Some("".to_string()), + bos_token: Some(TokenizerConfigToken::String("".to_string())), + eos_token: Some(TokenizerConfigToken::String("".to_string())), chat_template: None, tokenizer_class: None, completion_template: None, @@ -629,9 +630,9 @@ mod tests { let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); let expected = TemplateProcessing::builder() - .try_single(":0 $A:0 :1") + .try_single(":0 $A:0") .unwrap() - .try_pair(":0 $A:0 $B:1") + .try_pair(":0 $A:0 :1 $B:1") .unwrap() .special_tokens(vec![("".to_string(), 1)]) .build() diff --git a/router/src/server.rs b/router/src/server.rs index 0cb08d4e..d24774f9 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,17 +12,18 @@ use crate::kserve::{ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, - HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, - Token, TokenizeResponse, Usage, Validation, + GenerateResponse, 
diff --git a/router/src/server.rs b/router/src/server.rs
index 0cb08d4e..d24774f9 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -12,17 +12,18 @@ use crate::kserve::{
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
-    HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
-    Token, TokenizeResponse, Usage, Validation,
+    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
+    Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
+    Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
-    CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
+    CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
+    VertexResponse,
 };
-use crate::{FunctionDefinition, ToolCall, ToolType};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -635,7 +636,7 @@ async fn completions(
         ));
     }

-    if req.prompt.len() > info.max_client_batch_size {
+    if req.prompt.0.len() > info.max_client_batch_size {
         metrics::increment_counter!("tgi_request_failure", "err" => "validation");
         return Err((
             StatusCode::UNPROCESSABLE_ENTITY,
@@ -651,6 +652,7 @@ async fn completions(

     let generate_requests: Vec<GenerateRequest> = req
         .prompt
+        .0
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
@@ -705,7 +707,6 @@ async fn completions(
                                 event
                                     .json_data(CompletionCompleteChunk {
                                         id: "".to_string(),
-                                        object: "text_completion".to_string(),
                                         created: current_time,

                                         choices: vec![CompletionComplete {
@@ -932,7 +933,6 @@ async fn completions(

         let response = Completion {
             id: "".to_string(),
-            object: "text_completion".to_string(),
             created: current_time,
             model: info.model_id.clone(),
             system_fingerprint: format!(
@@ -1153,14 +1153,16 @@ async fn chat_completions(
                         };

                         event
-                            .json_data(ChatCompletionChunk::new(
-                                model_id.clone(),
-                                system_fingerprint.clone(),
-                                content,
-                                tool_calls,
-                                current_time,
-                                logprobs,
-                                stream_token.details.map(|d| d.finish_reason.to_string()),
+                            .json_data(CompletionType::ChatCompletionChunk(
+                                ChatCompletionChunk::new(
+                                    model_id.clone(),
+                                    system_fingerprint.clone(),
+                                    content,
+                                    tool_calls,
+                                    current_time,
+                                    logprobs,
+                                    stream_token.details.map(|d| d.finish_reason.to_string()),
+                                ),
                             ))
                             .unwrap_or_else(|e| {
                                 println!("Failed to serialize ChatCompletionChunk: {:?}", e);
@@ -1228,7 +1230,7 @@ async fn chat_completions(
             (None, Some(generation.generated_text))
         };
         // build the complete response object with the full text
-        let response = ChatCompletion::new(
+        let response = CompletionType::ChatCompletion(ChatCompletion::new(
             model_id,
             system_fingerprint,
             output,
@@ -1236,7 +1238,7 @@ async fn chat_completions(
             generation.details.unwrap(),
             logprobs,
             tool_calls,
-        );
+        ));

         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(response)).into_response())
diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py
index e74180e7..c8bccefe 100644
--- a/server/text_generation_server/layers/attention/__init__.py
+++ b/server/text_generation_server/layers/attention/__init__.py
@@ -1,6 +1,8 @@
 from text_generation_server.utils.import_utils import SYSTEM
 import os

+from .common import Seqlen
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
new file mode 100644
index 00000000..bd0717ce
--- /dev/null
+++ b/server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from text_generation_server.models.globals import FLASH_DECODING
+import torch
+from typing import Optional
+
+
+if FLASH_DECODING:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+        cu_seqlen_q: Optional[torch.Tensor]
+        cu_seqlen_k: Optional[torch.Tensor]
+
+        def __init__(self, input_lengths):
+            self.input_lengths = input_lengths
+            device = self.input_lengths.device
+            shape = self.input_lengths.shape
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cuda graphs don't like this and this is necessary to clamp within mistral
+            # Although FA2 might not want the clamping
+            # cu_seqlen_k[0] = 0
+            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+
+            self.cu_seqlen_q = cu_seqlen_q
+            self.cu_seqlen_k = cu_seqlen_k
+
+        def clamp(self, max):
+            # Flash decoding doesn't need to clamp
+            return self
+
+else:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
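In the flash-decoding variant, `cu_seqlen_q` is a plain `arange` because each decoding step contributes exactly one query token per sequence, while `cu_seqlen_k` is the exclusive prefix sum of the per-sequence KV lengths. A worked example without torch (plain Python, illustrative values):

```python
from itertools import accumulate

input_lengths = [3, 5, 2]          # KV tokens already cached per sequence
cu_seqlen_q = list(range(len(input_lengths) + 1))      # one query per sequence
cu_seqlen_k = [0] + list(accumulate(input_lengths))    # prefix sums of KV lengths

assert cu_seqlen_q == [0, 1, 2, 3]
assert cu_seqlen_k == [0, 3, 8, 10]
```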
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 583337bd..94b69899 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -1,5 +1,7 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.layers.attention import Seqlen

 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -21,7 +23,14 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    if FLASH_DECODING:
+        shape = key_cache.shape
+        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+    else:
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )


 def paged_attention(
@@ -32,7 +41,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    seqlen: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -53,7 +62,8 @@ def paged_attention(
     #
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = value_cache.shape[3]
+    # block_size = value_cache.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

@@ -62,58 +72,95 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    if FLASH_DECODING:
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda

-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
-    if use_v1:
-        ops.paged_attention_v1(
-            out,
+        # TODO fixme when flash contains the fix.
+        # Number of splits is not correctly handled
+        # by the current path
+        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
+        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
+        out2 = flash_attn_2_cuda.varlen_fwd(
             query,
             key_cache,
             value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
             None,
-            "auto",
-            1.0,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            False,  # return softmax
+            None,  # generator
         )
+        return out2[0]
     else:
-        # Run PagedAttention V2.
-        assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions, head_size),
-            dtype=out.dtype,
-            device=out.device,
-        )
-        exp_sums = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions),
-            dtype=torch.float32,
-            device=out.device,
-        )
-        max_logits = torch.empty_like(exp_sums)
+        input_lengths = seqlen.input_lengths
+        from vllm._C import ops

-        ops.paged_attention_v2(
-            out,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
+        use_v1 = max_s <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512
         )
+        if use_v1:
+            ops.paged_attention_v1(
+                out,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=out.dtype,
+                device=out.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=out.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+
+            ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        return out


 try:
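Under `FLASH_DECODING` the cache keeps a flat `(num_blocks, BLOCK_SIZE, num_heads, head_size)` layout, so flattening the first two dimensions lets a global slot index (`block * BLOCK_SIZE + offset`) address any token directly, exactly as the new `reshape_and_cache` branch does. A small torch sketch with illustrative shapes:

```python
import torch

num_blocks, block_size, num_heads, head_size = 4, 256, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
key = torch.ones(3, num_heads, head_size)       # 3 tokens to cache this step
slots = torch.tensor([0, 257, 511])             # global slot = block * block_size + offset

# Same write as the FLASH_DECODING branch of reshape_and_cache.
key_cache.view(-1, num_heads, head_size)[slots] = key
assert key_cache[1, 1].abs().sum() > 0          # slot 257 == block 1, offset 1
```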
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index bfab0119..45a0a03e 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -1,6 +1,7 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
+from text_generation_server.layers.attention import Seqlen

 SUPPORTS_WINDOWING = False

@@ -14,6 +15,7 @@ def attention(
     max_s,
     softmax_scale,
     window_size_left=-1,
+    causal=True,
 ):
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
@@ -28,7 +30,7 @@ def attention(
         0.0,
         softmax_scale,
         False,
-        True,
+        causal,
         False,
         None,
     )
@@ -54,10 +56,10 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    seqlen: Seqlen,
     max_s: int,
 ):
-    return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
         key_cache,
@@ -65,8 +67,9 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        input_lengths,
+        seqlen.input_lengths,
         BLOCK_SIZE,
         max_s,
         None,
     )
+    return out
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 91ed5818..99c490d5 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -1,6 +1,8 @@
 import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.layers.attention import Seqlen
 from loguru import logger

 major, minor = torch.cuda.get_device_capability()
@@ -26,7 +28,14 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    if FLASH_DECODING:
+        shape = key_cache.shape
+        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+    else:
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )


 def paged_attention(
@@ -37,7 +46,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    input_lengths: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -61,6 +70,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = input_lengths.input_lengths

     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -119,6 +129,7 @@ def paged_attention(
         "auto",
         1.0,
     )
+    return out


 if ENGINE != "triton":
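The V1/V2 split these kernels retain follows vLLM's heuristic: V1 when a single partition suffices or there is already plenty of parallel work, V2 (partition-then-reduce) otherwise. A quick numeric check, assuming the usual `_PARTITION_SIZE = 512`:

```python
_PARTITION_SIZE = 512

def use_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

assert use_v1(max_s=400, num_seqs=1, num_heads=32)        # one partition -> V1
assert not use_v1(max_s=4096, num_seqs=2, num_heads=32)   # 8 partitions, little work -> V2
assert use_v1(max_s=4096, num_seqs=64, num_heads=32)      # plenty of parallel work -> V1
```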
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 27be121d..6c1b09c0 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -12,7 +12,6 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = True

 try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -92,6 +92,7 @@ except ImportError as e:
     FLASH_ATTENTION = False

 if FLASH_ATTENTION:
+    __all__.append(FlashCausalLM)
     __all__.append(FlashGPT2)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 2850a6f3..e088f9aa 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        slots,
         input_lengths,
+        slots,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
         )

         residual = None
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 9d56e4ef..aea7f399 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index a71de61f..cfa6b2fe 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 82891823..842df0d4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 9deae8be..5d8c9515 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6b82aeca..77a7e2d5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -117,6 +118,11 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

+        # Setting defaults for baichuan custom config which doesn't apply them.
+        config.rope_theta = getattr(config, "rope_theta", 10000)
+        config.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
+        )
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
@@ -208,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index d2544155..396969cd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
+    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -514,7 +515,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)

         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 2e839d15..2d6a7f97 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index b87fd4ca..33aebc2b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 3f445f97..f237ea37 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 69f38c3a..1cc6a613 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 04d4ba51..e7614232 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index badfc367..30989a37 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index f6a2e15d..a0273c37 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 90c2079f..1a51ee76 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -30,9 +30,12 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
+    FLASH_DECODING,
+    BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
 )
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -45,7 +48,6 @@ from text_generation_server.utils.import_utils import (

 tracer = trace.get_tracer(__name__)

-BLOCK_SIZE: int = 16

 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
@@ -855,7 +857,23 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size

-        if SYSTEM == "ipex" and device == torch.device("cpu"):
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
@@ -907,6 +925,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
+        input_lengths_ = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

@@ -919,7 +938,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths_,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -927,6 +946,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1066,6 +1086,7 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        input_lengths = Seqlen(input_lengths=input_lengths)

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1152,6 +1173,7 @@ class FlashCausalLM(Model):
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
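For flash decoding, `FlashCausalLM` allocates keys and values with the same flat `(num_blocks, BLOCK_SIZE, num_heads, head_size)` shape instead of the paged layout (which, per the comment in `cuda.py`, keeps values as `[num_blocks, num_heads, head_size, block_size]` and splits the key head dimension by a vectorization factor `x`). A shape sketch with illustrative sizes:

```python
import torch

num_blocks, num_heads, head_size = 8, 4, 64
# FLASH_DECODING layout: one row per in-block token position, BLOCK_SIZE = 256.
flash_decoding_key = torch.empty(num_blocks, 256, num_heads, head_size)
# Paged layout splits head_size by the vectorization factor x, BLOCK_SIZE = 16.
x = 8
paged_key = torch.empty(num_blocks, num_heads, head_size // x, 16, x)

assert flash_decoding_key.view(-1, num_heads, head_size).shape[0] == num_blocks * 256
```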
torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashPhi is only available on GPU") diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 23528f0b..cd6078f1 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -19,6 +19,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashQwen2 is only available on GPU") diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 3658c626..5094f477 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,12 @@ from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli +FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} +BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 +if FLASH_DECODING: + logger.info("Using FLASH_DECODING") + + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: @@ -15,8 +21,6 @@ if cuda_graphs is not None: ) else: cuda_graphs = None - - # sorting the cuda graphs in descending order helps reduce the # memory impact and results in less memory usage if cuda_graphs is not None: diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 6d921721..011e0f63 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,6 +1,7 @@ import torch from loguru import logger import subprocess +import os def is_ipex_available(): @@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction): total_memory = torch.xpu.get_device_properties(device).total_memory device_id = device.index - query = f"xpu-smi dump -d {device_id} -m 18 -n 1" - output = subprocess.check_output(query.split()).decode("utf-8").split("\n") - used_memory = float(output[1].split(",")[-1]) * 1024 * 1024 - free_memory = int(total_memory * 0.95 - used_memory) + memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + free_memory = max( + 0, + int( + total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) + ), + ) return free_memory