mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Merge branch 'main' into ci_amd3
This commit is contained in:
commit
29a416078c
54
.github/workflows/build.yaml
vendored
54
.github/workflows/build.yaml
vendored
@ -16,7 +16,6 @@ jobs:
|
||||
build-and-push:
|
||||
outputs:
|
||||
docker_image: ${{ steps.final.outputs.docker_image }}
|
||||
base_docker_image: ${{ steps.final.outputs.base_docker_image }}
|
||||
docker_devices: ${{ steps.final.outputs.docker_devices }}
|
||||
docker_volume: ${{ steps.final.outputs.docker_volume}}
|
||||
runs_on: ${{ steps.final.outputs.runs_on }}
|
||||
@ -73,17 +72,13 @@ jobs:
|
||||
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
|
||||
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
|
||||
|
||||
- name: Tailscale
|
||||
uses: huggingface/tailscale-action@main
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
|
||||
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
|
||||
slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
- name: Initialize Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
install: true
|
||||
config-inline: |
|
||||
[registry."docker.io"]
|
||||
mirrors = ["registry.github-runners.huggingface.tech"]
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
@ -93,13 +88,6 @@ jobs:
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Login to internal Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
|
||||
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
|
||||
registry: registry.internal.huggingface.tech
|
||||
|
||||
- name: Login to Azure Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
@ -115,10 +103,9 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||
registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
|
||||
tags: |
|
||||
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
|
||||
|
||||
# If main, release or tag
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
@ -128,7 +115,7 @@ jobs:
|
||||
flavor: |
|
||||
latest=auto
|
||||
images: |
|
||||
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||
registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
|
||||
ghcr.io/huggingface/text-generation-inference
|
||||
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
|
||||
tags: |
|
||||
@ -136,7 +123,6 @@ jobs:
|
||||
type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }}
|
||||
type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
|
||||
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
id: build-and-push
|
||||
uses: docker/build-push-action@v4
|
||||
@ -150,30 +136,16 @@ jobs:
|
||||
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
|
||||
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
|
||||
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
|
||||
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
|
||||
|
||||
cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
|
||||
cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
|
||||
- name: Final
|
||||
id: final
|
||||
run: |
|
||||
echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
|
||||
echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
|
||||
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
|
||||
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
|
||||
echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if [[ ${{ inputs.hardware }} == "rocm" ]]
|
||||
then
|
||||
echo "base_docker_image=rocm/dev-ubuntu-22.04:6.1.1_hip_update" >> "$GITHUB_OUTPUT"
|
||||
elif [[ ${{ inputs.hardware }} == "cuda" ]]
|
||||
then
|
||||
echo "base_docker_image=nvidia/cuda:12.1.0-base-ubuntu22.04" >> "$GITHUB_OUTPUT"
|
||||
elif [[ ${{ inputs.hardware }} == "xpu" ]]
|
||||
then
|
||||
echo "base_docker_image=intel/intel-extension-for-pytorch:2.1.30-xpu" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ${{ inputs.hardware }} == "rocm" ]]
|
||||
then
|
||||
echo "docker_volume=/data/cache/.cache/huggingface/hub" >> "$GITHUB_OUTPUT"
|
||||
@ -191,7 +163,7 @@ jobs:
|
||||
# Ideally, we would use the image from registry.internal.huggingface.tech but we can not login to the private registry outside of tailscale,
|
||||
# and even adding a previous job with tailscale login still results in `Docker login for 'registry.internal.huggingface.tech' failed with exit code 1`.
|
||||
container:
|
||||
image: ${{ needs.build-and-push.outputs.base_docker_image }}
|
||||
image: ${{ needs.build-and-push.outputs.docker_image }}
|
||||
options: --shm-size "16gb" --ipc host -v ${{ needs.build-and-push.outputs.docker_volume }}:/data
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@ -207,8 +179,6 @@ jobs:
|
||||
echo "ls:"
|
||||
ls
|
||||
|
||||
pip3 install -U huggingface_hub
|
||||
|
||||
python3 integration-tests/clean_cache_and_download.py --token ${{ secrets.HF_TOKEN }} --cache-dir /data
|
||||
|
||||
# Avoid permissions issues in the next step not run within docker (File was unable to be removed Error: EACCES).
|
||||
@ -242,12 +212,6 @@ jobs:
|
||||
run: |
|
||||
make install-integration-tests
|
||||
|
||||
- name: Tailscale
|
||||
uses: huggingface/tailscale-action@main
|
||||
if: needs.build-and-push.outputs.runs_on != 'amd-gpu-tgi'
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
|
||||
|
@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
|
||||
WORKDIR /usr/src
|
||||
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
|
||||
|
||||
# Install server
|
||||
@ -132,6 +133,7 @@ RUN conda install -c conda-forge gperftools mkl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install triton
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
|
||||
};
|
||||
use crate::{
|
||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||
};
|
||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
@ -270,7 +272,11 @@ struct ChatTemplate {
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
fn new(
|
||||
template: String,
|
||||
bos_token: Option<TokenizerConfigToken>,
|
||||
eos_token: Option<TokenizerConfigToken>,
|
||||
) -> Self {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
@ -287,8 +293,8 @@ impl ChatTemplate {
|
||||
|
||||
Self {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
@ -301,9 +307,9 @@ impl ChatTemplate {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text(Text {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -340,6 +346,14 @@ impl ToolGrammar {
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::Function { function } => {
|
||||
let tool = req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == function.name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
|
||||
.clone();
|
||||
vec![tool]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
|
@ -39,7 +39,14 @@ impl SchedulerV2 {
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||
// Infer shared state
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
|
@ -39,9 +39,15 @@ impl SchedulerV3 {
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -53,23 +53,40 @@ pub enum ChatTemplateVersions {
|
||||
Multiple(Vec<ChatTemplate>),
|
||||
}
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
pub completion_template: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub bos_token: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub eos_token: Option<String>,
|
||||
pub bos_token: Option<TokenizerConfigToken>,
|
||||
pub eos_token: Option<TokenizerConfigToken>,
|
||||
pub tokenizer_class: Option<String>,
|
||||
pub add_bos_token: Option<bool>,
|
||||
pub add_eos_token: Option<bool>,
|
||||
}
|
||||
|
||||
impl HubTokenizerConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.and_then(|content| serde_json::from_str(&content).ok())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum TokenizerConfigToken {
|
||||
String(String),
|
||||
Object { content: String },
|
||||
}
|
||||
|
||||
impl TokenizerConfigToken {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
TokenizerConfigToken::String(s) => s,
|
||||
TokenizerConfigToken::Object { content } => content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,9 +117,10 @@ pub struct HubProcessorConfig {
|
||||
}
|
||||
|
||||
impl HubProcessorConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.and_then(|content| serde_json::from_str(&content).ok())
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,35 +139,6 @@ pub(crate) enum GrammarType {
|
||||
Regex(String),
|
||||
}
|
||||
|
||||
mod token_serde {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(s)),
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
|
||||
Ok(Some(content.to_string()))
|
||||
} else {
|
||||
Err(de::Error::custom(
|
||||
"content key not found in structured token",
|
||||
))
|
||||
}
|
||||
}
|
||||
Value::Null => Ok(None),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters {
|
||||
}
|
||||
}
|
||||
|
||||
mod prompt_serde {
|
||||
use serde::{self, Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
#[serde(try_from = "PromptDeserializer")]
|
||||
pub struct Prompt(pub Vec<String>);
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum PromptDeserializer {
|
||||
Single(String),
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
impl TryFrom<PromptDeserializer> for Prompt {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: PromptDeserializer) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
Value::String(s) => Ok(vec![s]),
|
||||
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
|
||||
"Empty array detected. Do not use an empty array for the prompt.",
|
||||
)),
|
||||
Value::Array(arr) => arr
|
||||
.iter()
|
||||
.map(|v| match v {
|
||||
Value::String(s) => Ok(s.to_owned()),
|
||||
_ => Err(serde::de::Error::custom("Expected a string")),
|
||||
})
|
||||
.collect(),
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"Expected a string or an array of strings",
|
||||
)),
|
||||
PromptDeserializer::Single(s) => Ok(Prompt(vec![s])),
|
||||
PromptDeserializer::Multiple(v) => {
|
||||
if v.is_empty() {
|
||||
Err(
|
||||
"Empty array detected. Do not use an empty array for the prompt."
|
||||
.to_string(),
|
||||
)
|
||||
} else {
|
||||
Ok(Prompt(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -396,8 +388,7 @@ pub struct CompletionRequest {
|
||||
|
||||
/// The prompt to generate completions for.
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
#[serde(deserialize_with = "prompt_serde::deserialize")]
|
||||
pub prompt: Vec<String>,
|
||||
pub prompt: Prompt,
|
||||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
@ -445,7 +436,6 @@ pub struct CompletionRequest {
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
||||
pub(crate) struct Completion {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
#[schema(example = "1706270835")]
|
||||
pub created: u64,
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete {
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletion {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
#[schema(example = "1706270835")]
|
||||
pub created: u64,
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
@ -562,6 +551,15 @@ pub(crate) struct Usage {
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, ToSchema)]
|
||||
#[serde(tag = "object")]
|
||||
enum CompletionType {
|
||||
#[serde(rename = "chat.completion.chunk")]
|
||||
ChatCompletionChunk(ChatCompletionChunk),
|
||||
#[serde(rename = "chat.completion")]
|
||||
ChatCompletion(ChatCompletion),
|
||||
}
|
||||
|
||||
impl ChatCompletion {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
@ -598,7 +596,6 @@ impl ChatCompletion {
|
||||
};
|
||||
Self {
|
||||
id: String::new(),
|
||||
object: "chat.completion".into(),
|
||||
created,
|
||||
model,
|
||||
system_fingerprint,
|
||||
@ -620,7 +617,6 @@ impl ChatCompletion {
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct CompletionCompleteChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<CompletionComplete>,
|
||||
pub model: String,
|
||||
@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk {
|
||||
#[derive(Clone, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
#[schema(example = "1706270978")]
|
||||
pub created: u64,
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
@ -710,7 +705,6 @@ impl ChatCompletionChunk {
|
||||
};
|
||||
Self {
|
||||
id: String::new(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model,
|
||||
system_fingerprint,
|
||||
@ -821,7 +815,6 @@ pub(crate) struct ChatRequest {
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||
pub tool_choice: Option<ToolType>,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option<String> {
|
||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||
)
|
||||
}
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
enum ToolType {
|
||||
FunctionName(String),
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolType {
|
||||
OneOf,
|
||||
FunctionName(String),
|
||||
Function { function: FunctionName },
|
||||
}
|
||||
|
||||
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
|
||||
mod deserialize_tool_choice {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct FunctionName {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(from = "ToolTypeDeserializer")]
|
||||
pub struct ToolChoice(pub Option<ToolType>);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ToolTypeDeserializer {
|
||||
None(Option<String>),
|
||||
Some(ToolType),
|
||||
}
|
||||
|
||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||
fn from(value: ToolTypeDeserializer) -> Self {
|
||||
match value {
|
||||
Value::String(s) => match s.as_str() {
|
||||
"none" => Ok(None),
|
||||
"auto" => Ok(Some(ToolType::OneOf)),
|
||||
_ => Ok(Some(ToolType::FunctionName(s))),
|
||||
ToolTypeDeserializer::None(opt) => match opt.as_deref() {
|
||||
Some("none") => ToolChoice(None),
|
||||
Some("auto") => ToolChoice(Some(ToolType::OneOf)),
|
||||
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
|
||||
None => ToolChoice(Some(ToolType::OneOf)),
|
||||
},
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map
|
||||
.get("function")
|
||||
.and_then(|v| v.get("name"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
Ok(Some(ToolType::FunctionName(content.to_string())))
|
||||
} else {
|
||||
Err(de::Error::custom("function key not found in tool choice"))
|
||||
}
|
||||
}
|
||||
Value::Null => Ok(Some(ToolType::OneOf)),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -950,26 +940,16 @@ pub(crate) struct ToolCall {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
struct Url {
|
||||
pub struct Url {
|
||||
url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
struct ImageUrl {
|
||||
image_url: Url,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
struct Text {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum MessageChunk {
|
||||
Text(Text),
|
||||
ImageUrl(ImageUrl),
|
||||
pub enum MessageChunk {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: Url },
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
@ -977,35 +957,31 @@ pub struct Message {
|
||||
#[schema(example = "user")]
|
||||
role: String,
|
||||
#[schema(example = "My name is David and I")]
|
||||
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
||||
content: Vec<MessageChunk>,
|
||||
pub content: MessageContent,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
mod message_content_serde {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
SingleText(String),
|
||||
MultipleChunks(Vec<MessageChunk>),
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Message {
|
||||
Text(String),
|
||||
Chunks(Vec<MessageChunk>),
|
||||
}
|
||||
let message: Message = Deserialize::deserialize(deserializer)?;
|
||||
let chunks = match message {
|
||||
Message::Text(text) => {
|
||||
vec![MessageChunk::Text(Text { text })]
|
||||
// Pushing a chunk to a single text message will convert it to a multiple chunks message
|
||||
impl MessageContent {
|
||||
pub fn push(&mut self, chunk: MessageChunk) {
|
||||
match self {
|
||||
MessageContent::SingleText(text) => {
|
||||
*self =
|
||||
MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]);
|
||||
}
|
||||
Message::Chunks(s) => s,
|
||||
};
|
||||
Ok(chunks)
|
||||
MessageContent::MultipleChunks(chunks) => {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1021,18 +997,17 @@ impl From<Message> for TextMessage {
|
||||
fn from(value: Message) -> Self {
|
||||
TextMessage {
|
||||
role: value.role,
|
||||
content: value
|
||||
.content
|
||||
.into_iter()
|
||||
.map(|c| match c {
|
||||
MessageChunk::Text(Text { text }) => text,
|
||||
MessageChunk::ImageUrl(image) => {
|
||||
let url = image.image_url.url;
|
||||
format!("")
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
content: match value.content {
|
||||
MessageContent::SingleText(text) => text,
|
||||
MessageContent::MultipleChunks(chunks) => chunks
|
||||
.into_iter()
|
||||
.map(|chunk| match chunk {
|
||||
MessageChunk::Text { text } => text,
|
||||
MessageChunk::ImageUrl { image_url } => format!("", image_url.url),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1240,9 +1215,16 @@ mod tests {
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
Some(TokenizerConfigToken::String(
|
||||
"<|begin▁of▁sentence|>".to_string()
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
config.eos_token,
|
||||
Some(TokenizerConfigToken::String(
|
||||
"<|end▁of▁sentence|>".to_string()
|
||||
))
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
|
||||
// in this case we expect the tokens to be encoded as structured tokens
|
||||
// we want the content of the structured token
|
||||
@ -1275,9 +1257,16 @@ mod tests {
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
Some(TokenizerConfigToken::Object {
|
||||
content: "<|begin▁of▁sentence|>".to_string()
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
config.eos_token,
|
||||
Some(TokenizerConfigToken::Object {
|
||||
content: "<|end▁of▁sentence|>".to_string()
|
||||
})
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1295,9 +1284,7 @@ mod tests {
|
||||
request.messages[0],
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![MessageChunk::Text(Text {
|
||||
text: "What is Deep Learning?".to_string()
|
||||
}),],
|
||||
content: MessageContent::SingleText("What is Deep Learning?".to_string()),
|
||||
name: None
|
||||
}
|
||||
);
|
||||
@ -1321,10 +1308,10 @@ mod tests {
|
||||
request.messages[0],
|
||||
Message{
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
||||
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
||||
],
|
||||
content: MessageContent::MultipleChunks(vec![
|
||||
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
|
||||
]),
|
||||
name: None
|
||||
}
|
||||
);
|
||||
@ -1334,10 +1321,10 @@ mod tests {
|
||||
fn text_message_convert() {
|
||||
let message = Message{
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
||||
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
||||
],
|
||||
content: MessageContent::MultipleChunks(vec![
|
||||
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
|
||||
]),
|
||||
name: None
|
||||
};
|
||||
let textmsg: TextMessage = message.into();
|
||||
|
@ -553,11 +553,11 @@ pub fn create_post_processor(
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
let bos_token_id = tokenizer
|
||||
.token_to_id(bos)
|
||||
.token_to_id(bos.as_str())
|
||||
.expect("Should have found the bos token id");
|
||||
special_tokens.push((bos.clone(), bos_token_id));
|
||||
single.push(format!("{}:0", bos));
|
||||
pair.push(format!("{}:0", bos));
|
||||
special_tokens.push((bos.as_str(), bos_token_id));
|
||||
single.push(format!("{}:0", bos.as_str()));
|
||||
pair.push(format!("{}:0", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -567,17 +567,17 @@ pub fn create_post_processor(
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(eos)
|
||||
.token_to_id(eos.as_str())
|
||||
.expect("Should have found the eos token id");
|
||||
special_tokens.push((eos.clone(), eos_token_id));
|
||||
single.push(format!("{}:0", eos));
|
||||
pair.push(format!("{}:0", eos));
|
||||
special_tokens.push((eos.as_str(), eos_token_id));
|
||||
single.push(format!("{}:0", eos.as_str()));
|
||||
pair.push(format!("{}:0", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
pair.push(format!("{}:1", bos));
|
||||
pair.push(format!("{}:1", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -585,7 +585,7 @@ pub fn create_post_processor(
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
pair.push(format!("{}:1", eos));
|
||||
pair.push(format!("{}:1", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -611,14 +611,15 @@ enum RouterError {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use text_generation_router::TokenizerConfigToken;
|
||||
|
||||
#[test]
|
||||
fn test_create_post_processor() {
|
||||
let tokenizer_config = HubTokenizerConfig {
|
||||
add_bos_token: None,
|
||||
add_eos_token: None,
|
||||
bos_token: Some("<s>".to_string()),
|
||||
eos_token: Some("</s>".to_string()),
|
||||
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||
chat_template: None,
|
||||
tokenizer_class: None,
|
||||
completion_template: None,
|
||||
@ -629,9 +630,9 @@ mod tests {
|
||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
||||
|
||||
let expected = TemplateProcessing::builder()
|
||||
.try_single("<s>:0 $A:0 <s>:1")
|
||||
.try_single("<s>:0 $A:0")
|
||||
.unwrap()
|
||||
.try_pair("<s>:0 $A:0 $B:1")
|
||||
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
|
||||
.unwrap()
|
||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
||||
.build()
|
||||
|
@ -12,17 +12,18 @@ use crate::kserve::{
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
|
||||
HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
|
||||
Token, TokenizeResponse, Usage, Validation,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
|
||||
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
|
||||
Usage, Validation,
|
||||
};
|
||||
use crate::{
|
||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
||||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
|
||||
VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
@ -635,7 +636,7 @@ async fn completions(
|
||||
));
|
||||
}
|
||||
|
||||
if req.prompt.len() > info.max_client_batch_size {
|
||||
if req.prompt.0.len() > info.max_client_batch_size {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
@ -651,6 +652,7 @@ async fn completions(
|
||||
|
||||
let generate_requests: Vec<GenerateRequest> = req
|
||||
.prompt
|
||||
.0
|
||||
.iter()
|
||||
.map(|prompt| GenerateRequest {
|
||||
inputs: prompt.to_string(),
|
||||
@ -705,7 +707,6 @@ async fn completions(
|
||||
event
|
||||
.json_data(CompletionCompleteChunk {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created: current_time,
|
||||
|
||||
choices: vec![CompletionComplete {
|
||||
@ -932,7 +933,6 @@ async fn completions(
|
||||
|
||||
let response = Completion {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created: current_time,
|
||||
model: info.model_id.clone(),
|
||||
system_fingerprint: format!(
|
||||
@ -1153,14 +1153,16 @@ async fn chat_completions(
|
||||
};
|
||||
|
||||
event
|
||||
.json_data(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
stream_token.details.map(|d| d.finish_reason.to_string()),
|
||||
.json_data(CompletionType::ChatCompletionChunk(
|
||||
ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
stream_token.details.map(|d| d.finish_reason.to_string()),
|
||||
),
|
||||
))
|
||||
.unwrap_or_else(|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
@ -1228,7 +1230,7 @@ async fn chat_completions(
|
||||
(None, Some(generation.generated_text))
|
||||
};
|
||||
// build the complete response object with the full text
|
||||
let response = ChatCompletion::new(
|
||||
let response = CompletionType::ChatCompletion(ChatCompletion::new(
|
||||
model_id,
|
||||
system_fingerprint,
|
||||
output,
|
||||
@ -1236,7 +1238,7 @@ async fn chat_completions(
|
||||
generation.details.unwrap(),
|
||||
logprobs,
|
||||
tool_calls,
|
||||
);
|
||||
));
|
||||
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(response)).into_response())
|
||||
|
@ -1,6 +1,8 @@
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
import os
|
||||
|
||||
from .common import Seqlen
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
|
44
server/text_generation_server/layers/attention/common.py
Normal file
44
server/text_generation_server/layers/attention/common.py
Normal file
@ -0,0 +1,44 @@
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
if FLASH_DECODING:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
cu_seqlen_q: Optional[torch.Tensor]
|
||||
cu_seqlen_k: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, input_lengths):
|
||||
self.input_lengths = input_lengths
|
||||
device = self.input_lengths.device
|
||||
shape = self.input_lengths.shape
|
||||
cu_seqlen_q = torch.arange(
|
||||
shape[0] + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
||||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||
# Although FA2 might not want the clamping
|
||||
# cu_seqlen_k[0] = 0
|
||||
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
|
||||
|
||||
self.cu_seqlen_q = cu_seqlen_q
|
||||
self.cu_seqlen_k = cu_seqlen_k
|
||||
|
||||
def clamp(self, max):
|
||||
# Flash decoding doesn't need to clamp
|
||||
return self
|
||||
|
||||
else:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
|
||||
def clamp(self, max):
|
||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
@ -21,7 +23,14 @@ def reshape_and_cache(
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||
if FLASH_DECODING:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
@ -32,7 +41,7 @@ def paged_attention(
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
@ -53,7 +62,8 @@ def paged_attention(
|
||||
#
|
||||
|
||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# block_size = value_cache.shape[3]
|
||||
block_size = BLOCK_SIZE
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
|
||||
@ -62,58 +72,95 @@ def paged_attention(
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
from vllm._C import ops
|
||||
if FLASH_DECODING:
|
||||
max_q = 1
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
# TODO fixme when flash contains the fix.
|
||||
# Number of splits is not correctly handled
|
||||
# by the current path
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
|
||||
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
||||
out2 = flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0, # dropout
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
True, # causal
|
||||
-1, # Window_left
|
||||
-1, # Window right
|
||||
False, # return softmax
|
||||
None, # generator
|
||||
)
|
||||
return out2[0]
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
input_lengths = seqlen.input_lengths
|
||||
from vllm._C import ops
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
use_v1 = max_s <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
@ -14,6 +15,7 @@ def attention(
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
):
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return ipex.llm.functional.varlen_attention(
|
||||
@ -28,7 +30,7 @@ def attention(
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
@ -54,10 +56,10 @@ def paged_attention(
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
@ -65,8 +67,9 @@ def paged_attention(
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
seqlen.input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from loguru import logger
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
@ -26,7 +28,14 @@ def reshape_and_cache(
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||
if FLASH_DECODING:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
@ -37,7 +46,7 @@ def paged_attention(
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
input_lengths: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
@ -61,6 +70,7 @@ def paged_attention(
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
input_lengths = input_lengths.input_lengths
|
||||
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
@ -119,6 +129,7 @@ def paged_attention(
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
if ENGINE != "triton":
|
||||
|
@ -12,7 +12,6 @@ from pathlib import Path
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.bloom import BLOOMSharded
|
||||
from text_generation_server.models.mpt import MPTSharded
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||
@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
||||
FLASH_ATTENTION = True
|
||||
|
||||
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
@ -92,6 +92,7 @@ except ImportError as e:
|
||||
FLASH_ATTENTION = False
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashCausalLM)
|
||||
__all__.append(FlashGPT2)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRWSharded)
|
||||
|
@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
slots,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
|
||||
)
|
||||
|
||||
residual = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
|
@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -117,6 +118,11 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
# Setting defaults for baichuan custom config which doesn't apply them.
|
||||
config.rope_theta = getattr(config, "rope_theta", 10000)
|
||||
config.num_key_value_heads = getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
)
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
@ -208,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
@ -514,7 +515,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
qkv[:, 0],
|
||||
kv_cache[0],
|
||||
|
@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
@ -30,9 +30,12 @@ from text_generation_server.models.types import (
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.globals import (
|
||||
MEM_POOL,
|
||||
FLASH_DECODING,
|
||||
BLOCK_SIZE,
|
||||
CUDA_GRAPHS,
|
||||
get_adapter_to_index,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
@ -45,7 +48,6 @@ from text_generation_server.utils.import_utils import (
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
@ -855,7 +857,23 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
if FLASH_DECODING:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
@ -907,6 +925,7 @@ class FlashCausalLM(Model):
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths,
|
||||
}
|
||||
input_lengths_ = Seqlen(input_lengths=input_lengths)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs]["graph"] = graph
|
||||
|
||||
@ -919,7 +938,7 @@ class FlashCausalLM(Model):
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=input_lengths_,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
@ -927,6 +946,7 @@ class FlashCausalLM(Model):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
@ -1066,6 +1086,7 @@ class FlashCausalLM(Model):
|
||||
|
||||
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
@ -1152,6 +1173,7 @@ class FlashCausalLM(Model):
|
||||
cuda_graph = None
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma is only available on GPU")
|
||||
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma2 is only available on GPU")
|
||||
|
||||
|
@ -153,7 +153,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
|
||||
# TODO: this is a hack to avoid the gate_proj for
|
||||
# FlashStarcoder2 that doesnt have these layers
|
||||
if hasattr(layer.mlp, "gate_up_proj"):
|
||||
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||
|
||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
||||
|
||||
|
@ -5,6 +5,12 @@ from typing import Dict
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
if FLASH_DECODING:
|
||||
logger.info("Using FLASH_DECODING")
|
||||
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
try:
|
||||
@ -15,8 +21,6 @@ if cuda_graphs is not None:
|
||||
)
|
||||
else:
|
||||
cuda_graphs = None
|
||||
|
||||
|
||||
# sorting the cuda graphs in descending order helps reduce the
|
||||
# memory impact and results in less memory usage
|
||||
if cuda_graphs is not None:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from loguru import logger
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
|
||||
def is_ipex_available():
|
||||
@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction):
|
||||
def get_xpu_free_memory(device, memory_fraction):
|
||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
||||
device_id = device.index
|
||||
query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
|
||||
output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
|
||||
used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
|
||||
free_memory = int(total_memory * 0.95 - used_memory)
|
||||
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
|
||||
free_memory = max(
|
||||
0,
|
||||
int(
|
||||
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
|
||||
),
|
||||
)
|
||||
return free_memory
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user