Merge branch 'main' into mi300-compat

fxmarty 2024-05-16 11:02:51 +02:00
commit f219124711
35 changed files with 2290 additions and 164 deletions

View File

@@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
-EC2_AMI_ID: ami-03cfed9ea28f4b002
+EC2_AMI_ID: ami-0789b6925c11b1fb2
EC2_INSTANCE_TYPE: g5.12xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6

View File

@@ -43,7 +43,7 @@ ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
-ARG MAMBA_VERSION=23.3.1-1
+ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
@@ -181,6 +181,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
ca-certificates \
make \
curl \
+git \
&& rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch installed

View File

@@ -14,5 +14,10 @@
__version__ = "0.6.0"
+DEPRECATION_WARNING = (
+    "`text_generation` clients are deprecated and will be removed in the near future. "
+    "Please use the `InferenceClient` from the `huggingface_hub` package instead."
+)
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

View File

@@ -1,10 +1,12 @@
import json
import requests
+import warnings
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
+from text_generation import DEPRECATION_WARNING
from text_generation.types import (
StreamResponse,
Response,
@@ -19,6 +21,9 @@ from text_generation.types import (
)
from text_generation.errors import parse_error
+# emit deprecation warnings
+warnings.simplefilter("always", DeprecationWarning)
class Client:
"""Client to make calls to a text-generation-inference instance
@@ -59,6 +64,7 @@ class Client:
timeout (`int`):
Timeout in seconds
"""
+warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
self.base_url = base_url
self.headers = headers
self.cookies = cookies
@@ -449,6 +455,7 @@ class AsyncClient:
timeout (`int`):
Timeout in seconds
"""
+warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
self.base_url = base_url
self.headers = headers
self.cookies = cookies
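With the warning in place, migrating away from the deprecated client is mostly a one-line change. A minimal sketch, assuming a TGI instance already listening on http://localhost:8080 (the URL is illustrative):

# Deprecated path: still works, but now emits a DeprecationWarning.
from text_generation import Client

client = Client("http://localhost:8080")
print(client.generate("What is deep learning?", max_new_tokens=10).generated_text)

# Recommended replacement from huggingface_hub.
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")
print(client.text_generation("What is deep learning?", max_new_tokens=10))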

View File

@@ -9,6 +9,7 @@ The following models are optimized and can be served with TGI, which uses custom
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
+- [GPT-2](https://huggingface.co/openai-community/gpt2)
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Llama](https://github.com/facebookresearch/llama)
- [OPT](https://huggingface.co/facebook/opt-66b)

New binary file added (image, 66 KiB); content not shown.

View File

@@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1835938,
"text": " is"
},
{
"id": 2769,
"logprob": -9.171875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6425781,
"text": " learning"
},
{
"id": 30,
"logprob": -0.7314453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.68603516,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.005393982,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.31079102,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08300781,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.58984375,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.953125,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0957031,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8095703,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9375,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
}

View File

@@ -0,0 +1,398 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1835938,
"text": " is"
},
{
"id": 2769,
"logprob": -9.171875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6425781,
"text": " learning"
},
{
"id": 30,
"logprob": -0.7314453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.68603516,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.005672455,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3251953,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08294678,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5854492,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9423828,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0800781,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8369141,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0683594,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9711914,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
}
]

View File

@@ -0,0 +1,25 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 2,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 54901,
"logprob": -0.72753906,
"special": false,
"text": "beach"
},
{
"id": 1,
"logprob": -0.011009216,
"special": true,
"text": "<eos>"
}
],
"top_tokens": null
},
"generated_text": "beach"
}

View File

@@ -0,0 +1,44 @@
import pytest
@pytest.fixture(scope="module")
def flash_gpt2_handle(launcher):
with launcher("openai-community/gpt2", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gpt2(flash_gpt2_handle):
await flash_gpt2_handle.health(300)
return flash_gpt2_handle.client
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate(
"What is deep learning?",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load(
flash_gpt2,
"What is deep learning?",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4
assert all(
[text == generated_texts[0] for text in generated_texts]
), generated_texts
assert responses == response_snapshot

View File

@@ -0,0 +1,39 @@
import pytest
import requests
import io
import base64
@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
with launcher(
"google/paligemma-3b-pt-224",
num_shard=1,
revision="float16",
max_input_length=4000,
max_total_tokens=4096,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
await flash_pali_gemma_handle.health(300)
return flash_pali_gemma_handle.client
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
cow = get_cow_beach()
inputs = f"![]({cow})Where is the cow standing?\n"
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
assert response.generated_text == "beach"
assert response == response_snapshot
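For reference, the same request can be issued directly against TGI's /generate endpoint. A sketch assuming a local instance serving the model at http://localhost:8080 and the test image available at the path used above:

import base64
import requests

# Build the same data-URI prompt the test constructs with get_cow_beach().
with open("integration-tests/images/cow_beach.png", "rb") as f:
    cow = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": f"![]({cow})Where is the cow standing?\n",
    "parameters": {"max_new_tokens": 20},
}
resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
print(resp.json()["generated_text"])  # the snapshot above expects "beach"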

View File

@@ -100,7 +100,6 @@ impl LlavaNext {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
image_size: usize,
@@ -108,7 +107,6 @@ pub struct ClipVisionModel {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
@@ -118,6 +116,24 @@ impl Idefics2 {
}
}
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct PaliTextConfig {
+    num_image_tokens: usize,
+}
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {
+    text_config: PaliTextConfig,
+}
+impl Paligemma {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        self.text_config.num_image_tokens
+    }
+}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
@@ -132,6 +148,7 @@ pub enum Config {
Santacoder,
Bloom,
Mpt,
+Gpt2,
GptNeox,
Phi,
#[serde(rename = "phi-msft")]
@@ -139,6 +156,7 @@ pub enum Config {
Phi3,
Llama,
Baichuan,
+Paligemma(Paligemma),
Gemma,
Cohere,
Drbx,
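The new Paligemma variant is selected from the model's config.json via the model_type tag on the Config enum, and the per-image slot count comes straight from the config rather than from image dimensions. A minimal sketch of the expected shape; the num_image_tokens value of 256 is an assumption for a 224x224 checkpoint, not something read from this diff:

import json

# Assumed config.json excerpt for a PaliGemma checkpoint (illustrative values).
raw = json.loads('{"model_type": "paligemma", "text_config": {"num_image_tokens": 256}}')

# Mirrors Paligemma::get_number_of_features: height/width are ignored and the
# fixed per-image slot count from the config is returned.
def get_number_of_features(config: dict, _height: int, _width: int) -> int:
    return config["text_config"]["num_image_tokens"]

print(get_number_of_features(raw, 224, 224))  # 256 under this assumption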

View File

@@ -979,24 +979,28 @@ mod tests {
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
],
bos_token: Some("[BOS]"),
@@ -1049,30 +1053,35 @@ mod tests {
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("Hi again!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
],
bos_token: Some("[BOS]"),
@@ -1130,24 +1139,28 @@ mod tests {
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
],
bos_token: Some("[BOS]"),
@@ -1189,24 +1202,28 @@ mod tests {
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
],
bos_token: Some("[BOS]"),
@@ -1234,18 +1251,21 @@ mod tests {
content: Some("Hello, how are you?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: Some("I'm doing great. How can I help you today?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("I'd like to show off how chat templating works!".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
];
@@ -1257,6 +1277,7 @@ mod tests {
),
name: None,
tool_calls: None,
+tool_call_id: None,
}]
.iter()
.chain(&example_chat)
@@ -1401,12 +1422,14 @@ mod tests {
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some("How many helicopters can a human eat in one sitting?".to_string()),
name: None,
tool_calls: None,
+tool_call_id: None,
},
],
add_generation_prompt: true,

View File

@@ -546,6 +546,7 @@ impl ChatCompletion {
content: output,
name: None,
tool_calls,
+tool_call_id: None,
},
logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@@ -881,7 +882,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall {
-pub id: u32,
+pub id: String,
pub r#type: String,
pub function: FunctionDefinition,
}
@@ -954,13 +955,16 @@ pub(crate) struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")]
-#[serde(deserialize_with = "message_content_serde::deserialize")]
+#[serde(default, deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")]
pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
+#[serde(default, skip_serializing_if = "Option::is_none")]
+#[schema(example = "\"get_weather\"")]
+pub tool_call_id: Option<String>,
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
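With tool_call_id now accepted on Message, content no longer required, and ToolCall.id turned into a string, an OpenAI-style tool round-trip can be expressed in the messages array of a chat request. A sketch of such a payload as a Python literal; the tool name, arguments, and id values are illustrative, not taken from the diff:

# Illustrative messages array for a chat completions request.
messages = [
    {"role": "user", "content": "What is the weather like in Brooklyn today?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                # ToolCall.id is now a string ("0") instead of an integer.
                "id": "0",
                "type": "function",
                "function": {"name": "get_weather", "arguments": {"location": "Brooklyn"}},
            }
        ],
    },
    # The tool result points back at the call it answers via tool_call_id.
    {"role": "tool", "content": "It is 22C and sunny.", "tool_call_id": "0"},
]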

View File

@@ -988,7 +988,6 @@ async fn chat_completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
let ChatRequest {
logprobs,
max_tokens,
@@ -1160,7 +1159,7 @@ async fn chat_completions(
)
})?;
let tool_calls = vec![ToolCall {
-id: 0,
+id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,

View File

@@ -544,6 +544,30 @@ fn prepare_input(
inputs = modified_inputs;
tokenizer_query
}
+Some(Config::Paligemma(config)) => {
+    let mut modified_inputs = String::with_capacity(inputs.len());
+    let mut tokenizer_query = String::with_capacity(inputs.len());
+    let mut start = 0;
+    for chunk in RE.find_iter(&inputs) {
+        let chunk_start = chunk.start();
+        let chunk_end = chunk.end();
+        if chunk_start != start {
+            modified_inputs.push_str(&inputs[start..chunk_start]);
+            tokenizer_query.push_str(&inputs[start..chunk_start]);
+        }
+        let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+        let slots = config.get_number_of_features(height, width);
+        tokenizer_query.push_str(&"<image>".repeat(slots));
+        modified_inputs.push_str(&image_uri);
+        start = chunk_end;
+    }
+    if start != inputs.len() - 1 {
+        modified_inputs.push_str(&inputs[start..]);
+        tokenizer_query.push_str(&inputs[start..]);
+    }
+    inputs = modified_inputs;
+    tokenizer_query
+}
Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
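The added branch mirrors the neighboring image-handling arms: each ![](...) image link in the prompt is kept in place for the model input, while the query sent to the tokenizer gets one <image> placeholder per slot reported by the config. A rough Python rendering of that logic; the regex is a stand-in for the router's RE, and the real code also fetches and validates the image via fetch_image rather than passing the link through unchanged:

import re

# Stand-in for the router's image-markdown regex.
IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")

def expand_paligemma_prompt(inputs: str, num_image_tokens: int) -> tuple[str, str]:
    """Return (modified_inputs, tokenizer_query) for a PaliGemma prompt."""
    modified_inputs, tokenizer_query = [], []
    start = 0
    for match in IMAGE_RE.finditer(inputs):
        # Text between images is copied verbatim into both strings.
        modified_inputs.append(inputs[start:match.start()])
        tokenizer_query.append(inputs[start:match.start()])
        # The model input keeps the image reference; the tokenizer query gets a
        # fixed number of <image> slots taken from the model config.
        modified_inputs.append(match.group(0))
        tokenizer_query.append("<image>" * num_image_tokens)
        start = match.end()
    # Copy any trailing text after the last image.
    modified_inputs.append(inputs[start:])
    tokenizer_query.append(inputs[start:])
    return "".join(modified_inputs), "".join(tokenizer_query)

# Example: one image followed by a question, with an assumed slot count of 256.
prompt = "![](data:image/png;base64,...)Where is the cow standing?\n"
model_input, tok_query = expand_paligemma_prompt(prompt, 256)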

server/poetry.lock (generated): 225 changed lines
View File

@ -359,43 +359,45 @@ files = [
[[package]] [[package]]
name = "datasets" name = "datasets"
version = "2.14.4" version = "2.19.1"
description = "HuggingFace community-driven open-source library of datasets" description = "HuggingFace community-driven open-source library of datasets"
optional = true optional = true
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, {file = "datasets-2.19.1-py3-none-any.whl", hash = "sha256:f7a78d15896f45004ccac1c298f3c7121f92f91f6f2bfbd4e4f210f827e6e411"},
{file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"},
] ]
[package.dependencies] [package.dependencies]
aiohttp = "*" aiohttp = "*"
dill = ">=0.3.0,<0.3.8" dill = ">=0.3.0,<0.3.9"
fsspec = {version = ">=2021.11.1", extras = ["http"]} filelock = "*"
huggingface-hub = ">=0.14.0,<1.0.0" fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
huggingface-hub = ">=0.21.2"
multiprocess = "*" multiprocess = "*"
numpy = ">=1.17" numpy = ">=1.17"
packaging = "*" packaging = "*"
pandas = "*" pandas = "*"
pyarrow = ">=8.0.0" pyarrow = ">=12.0.0"
pyarrow-hotfix = "*"
pyyaml = ">=5.1" pyyaml = ">=5.1"
requests = ">=2.19.0" requests = ">=2.19.0"
tqdm = ">=4.62.1" tqdm = ">=4.62.1"
xxhash = "*" xxhash = "*"
[package.extras] [package.extras]
apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] apache-beam = ["apache-beam (>=2.26.0)"]
audio = ["librosa", "soundfile (>=0.12.1)"] audio = ["librosa", "soundfile (>=0.12.1)"]
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] quality = ["ruff (>=0.3.0)"]
s3 = ["s3fs"] s3 = ["s3fs"]
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] tensorflow = ["tensorflow (>=2.6.0)"]
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] tensorflow-gpu = ["tensorflow (>=2.6.0)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
torch = ["torch"] torch = ["torch"]
vision = ["Pillow (>=6.2.1)"] vision = ["Pillow (>=6.2.1)"]
@ -418,17 +420,18 @@ dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]] [[package]]
name = "dill" name = "dill"
version = "0.3.7" version = "0.3.8"
description = "serialize all of Python" description = "serialize all of Python"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.8"
files = [ files = [
{file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
{file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
] ]
[package.extras] [package.extras]
graph = ["objgraph (>=1.7.2)"] graph = ["objgraph (>=1.7.2)"]
profile = ["gprof2dot (>=2022.7.29)"]
[[package]] [[package]]
name = "diskcache" name = "diskcache"
@ -871,13 +874,13 @@ files = [
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.19.4" version = "0.23.0"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"}, {file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
{file = "huggingface_hub-0.19.4.tar.gz", hash = "sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"}, {file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
] ]
[package.dependencies] [package.dependencies]
@ -890,16 +893,17 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3" typing-extensions = ">=3.7.4.3"
[package.extras] [package.extras]
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"] cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] hf-transfer = ["hf-transfer (>=0.1.4)"]
quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] inference = ["aiohttp", "minijinja (>=1.0)"]
quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"] tensorflow = ["graphviz", "pydot", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] tensorflow-testing = ["keras (<3.0)", "tensorflow"]
torch = ["torch"] testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
torch = ["safetensors", "torch"]
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]] [[package]]
@ -1282,31 +1286,27 @@ files = [
[[package]] [[package]]
name = "multiprocess" name = "multiprocess"
version = "0.70.15" version = "0.70.16"
description = "better multiprocessing and multithreading in Python" description = "better multiprocessing and multithreading in Python"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.8"
files = [ files = [
{file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
{file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
{file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
{file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
{file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
{file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
{file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
{file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
{file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
{file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
{file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
{file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
{file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"},
{file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"},
{file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"},
{file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"},
] ]
[package.dependencies] [package.dependencies]
dill = ">=0.3.7" dill = ">=0.3.8"
[[package]] [[package]]
name = "nest-asyncio" name = "nest-asyncio"
@ -2034,52 +2034,63 @@ files = [
[[package]] [[package]]
name = "pyarrow" name = "pyarrow"
version = "16.0.0" version = "16.1.0"
description = "Python library for Apache Arrow" description = "Python library for Apache Arrow"
optional = true optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyarrow-16.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:22a1fdb1254e5095d629e29cd1ea98ed04b4bbfd8e42cc670a6b639ccc208b60"}, {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
{file = "pyarrow-16.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:574a00260a4ed9d118a14770edbd440b848fcae5a3024128be9d0274dbcaf858"}, {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0815d0ddb733b8c1b53a05827a91f1b8bde6240f3b20bf9ba5d650eb9b89cdf"}, {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df0080339387b5d30de31e0a149c0c11a827a10c82f0c67d9afae3981d1aabb7"}, {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:edf38cce0bf0dcf726e074159c60516447e4474904c0033f018c1f33d7dac6c5"}, {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:91d28f9a40f1264eab2af7905a4d95320ac2f287891e9c8b0035f264fe3c3a4b"}, {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
{file = "pyarrow-16.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:99af421ee451a78884d7faea23816c429e263bd3618b22d38e7992c9ce2a7ad9"}, {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
{file = "pyarrow-16.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d22d0941e6c7bafddf5f4c0662e46f2075850f1c044bf1a03150dd9e189427ce"}, {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
{file = "pyarrow-16.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:266ddb7e823f03733c15adc8b5078db2df6980f9aa93d6bb57ece615df4e0ba7"}, {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cc23090224b6594f5a92d26ad47465af47c1d9c079dd4a0061ae39551889efe"}, {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56850a0afe9ef37249d5387355449c0f94d12ff7994af88f16803a26d38f2016"}, {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:705db70d3e2293c2f6f8e84874b5b775f690465798f66e94bb2c07bab0a6bb55"}, {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:5448564754c154997bc09e95a44b81b9e31ae918a86c0fcb35c4aa4922756f55"}, {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
{file = "pyarrow-16.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:729f7b262aa620c9df8b9967db96c1575e4cfc8c25d078a06968e527b8d6ec05"}, {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
{file = "pyarrow-16.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:fb8065dbc0d051bf2ae2453af0484d99a43135cadabacf0af588a3be81fbbb9b"}, {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
{file = "pyarrow-16.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:20ce707d9aa390593ea93218b19d0eadab56390311cb87aad32c9a869b0e958c"}, {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5823275c8addbbb50cd4e6a6839952682a33255b447277e37a6f518d6972f4e1"}, {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab8b9050752b16a8b53fcd9853bf07d8daf19093533e990085168f40c64d978"}, {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:42e56557bc7c5c10d3e42c3b32f6cff649a29d637e8f4e8b311d334cc4326730"}, {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a7abdee4a4a7cfa239e2e8d721224c4b34ffe69a0ca7981354fe03c1328789b"}, {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
{file = "pyarrow-16.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:ef2f309b68396bcc5a354106741d333494d6a0d3e1951271849787109f0229a6"}, {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
{file = "pyarrow-16.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ed66e5217b4526fa3585b5e39b0b82f501b88a10d36bd0d2a4d8aa7b5a48e2df"}, {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
{file = "pyarrow-16.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc8814310486f2a73c661ba8354540f17eef51e1b6dd090b93e3419d3a097b3a"}, {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c2f5e239db7ed43e0ad2baf46a6465f89c824cc703f38ef0fde927d8e0955f7"}, {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f293e92d1db251447cb028ae12f7bc47526e4649c3a9924c8376cab4ad6b98bd"}, {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:dd9334a07b6dc21afe0857aa31842365a62eca664e415a3f9536e3a8bb832c07"}, {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d91073d1e2fef2c121154680e2ba7e35ecf8d4969cc0af1fa6f14a8675858159"}, {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
{file = "pyarrow-16.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:71d52561cd7aefd22cf52538f262850b0cc9e4ec50af2aaa601da3a16ef48877"}, {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
{file = "pyarrow-16.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:b93c9a50b965ee0bf4fef65e53b758a7e8dcc0c2d86cebcc037aaaf1b306ecc0"}, {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
{file = "pyarrow-16.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d831690844706e374c455fba2fb8cfcb7b797bfe53ceda4b54334316e1ac4fa4"}, {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35692ce8ad0b8c666aa60f83950957096d92f2a9d8d7deda93fb835e6053307e"}, {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dd3151d098e56f16a8389c1247137f9e4c22720b01c6f3aa6dec29a99b74d80"}, {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bd40467bdb3cbaf2044ed7a6f7f251c8f941c8b31275aaaf88e746c4f3ca4a7a"}, {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:00a1dcb22ad4ceb8af87f7bd30cc3354788776c417f493089e0a0af981bc8d80"}, {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
{file = "pyarrow-16.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fda9a7cebd1b1d46c97b511f60f73a5b766a6de4c5236f144f41a5d5afec1f35"}, {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
{file = "pyarrow-16.0.0.tar.gz", hash = "sha256:59bb1f1edbbf4114c72415f039f1359f1a57d166a331c3229788ccbfbb31689a"}, {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
] ]
[package.dependencies] [package.dependencies]
numpy = ">=1.16.6" numpy = ">=1.16.6"
[[package]]
name = "pyarrow-hotfix"
version = "0.6"
description = ""
optional = true
python-versions = ">=3.5"
files = [
{file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"},
{file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
]
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "2.7.1" version = "2.7.1"
@ -3016,18 +3027,16 @@ telegram = ["requests"]
[[package]] [[package]]
name = "transformers" name = "transformers"
version = "4.40.2" version = "4.41.0.dev0"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = []
{file = "transformers-4.40.2-py3-none-any.whl", hash = "sha256:71cb94301ec211a2e1d4b8c8d18dcfaa902dfa00a089dceca167a8aa265d6f2d"}, develop = false
{file = "transformers-4.40.2.tar.gz", hash = "sha256:657b6054a2097671398d976ad46e60836e7e15f9ea9551631a96e33cb9240649"},
]
[package.dependencies] [package.dependencies]
filelock = "*" filelock = "*"
huggingface-hub = ">=0.19.3,<1.0" huggingface-hub = ">=0.23.0,<1.0"
numpy = ">=1.17" numpy = ">=1.17"
packaging = ">=20.0" packaging = ">=20.0"
pyyaml = ">=5.1" pyyaml = ">=5.1"
@ -3040,27 +3049,25 @@ tqdm = ">=4.27"
[package.extras] [package.extras]
accelerate = ["accelerate (>=0.21.0)"] accelerate = ["accelerate (>=0.21.0)"]
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] all = ["Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
codecarbon = ["codecarbon (==1.2.0)"] codecarbon = ["codecarbon (==1.2.0)"]
deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] dev = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] dev-tensorflow = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] dev-torch = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
docs-specific = ["hf-doc-builder"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"] ftfy = ["ftfy"]
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"] modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6,<0.15.0)"] natten = ["natten (>=0.14.6,<0.15.0)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"] optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"] ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
sagemaker = ["sagemaker (>=2.31.0)"] sagemaker = ["sagemaker (>=2.31.0)"]
@@ -3069,19 +3076,25 @@ serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"] sigopt = ["sigopt"]
sklearn = ["scikit-learn"] sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
timm = ["timm"] timm = ["timm"]
tokenizers = ["tokenizers (>=0.19,<0.20)"] tokenizers = ["tokenizers (>=0.19,<0.20)"]
torch = ["accelerate (>=0.21.0)", "torch"] torch = ["accelerate (>=0.21.0)", "torch"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib_metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"]
video = ["av (==9.2.0)", "decord (==0.6.0)"] video = ["av (==9.2.0)", "decord (==0.6.0)"]
vision = ["Pillow (>=10.0.1,<=15.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"]
[package.source]
type = "git"
url = "https://github.com/huggingface/transformers.git"
reference = "b8aee2e"
resolved_reference = "b8aee2e918d7ba2d5e9e80162ae26b4806873307"
[[package]] [[package]]
name = "triton" name = "triton"
version = "2.3.0" version = "2.3.0"
@@ -3488,4 +3501,4 @@ torch = ["torch"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<3.13" python-versions = ">=3.9,<3.13"
content-hash = "df83b265d0263870b5d1ae8bfd847f406abef90868fdf528ff38527b512f86c0" content-hash = "b2a29b0b6e32d0e7043e94b984c5731f2c27c5d95feccbeb80bd890db22d6c4a"

View File

@@ -25,8 +25,9 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2" hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97" sentencepiece = "^0.1.97"
tokenizers = "^0.19.1" tokenizers = "^0.19.1"
huggingface-hub = "^0.19.3" huggingface-hub = "^0.23"
transformers = "^4.40" # transformers = "^4.40"
transformers = { git = "https://github.com/huggingface/transformers.git", rev="b8aee2e" }
einops = "^0.6.1" einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true } texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true } datasets = { version = "^2.14.0", optional = true }

View File

@@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13" transformers @ git+https://github.com/huggingface/transformers.git@b8aee2e918d7ba2d5e9e80162ae26b4806873307 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13" transformers @ git+https://github.com/huggingface/transformers.git@b8aee2e918d7ba2d5e9e80162ae26b4806873307 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@@ -70,7 +70,7 @@ class Linear8bitLt(torch.nn.Module):
return out return out
class Linear4bit(nn.Module): class Linear4bit(torch.nn.Module):
def __init__(self, weight, bias, quant_type): def __init__(self, weight, bias, quant_type):
super().__init__() super().__init__()
self.weight = Params4bit( self.weight = Params4bit(

View File

@@ -15,9 +15,9 @@ class FastLinear(torch.nn.Module):
bias, bias,
) -> None: ) -> None:
super().__init__() super().__init__()
self.weight = torch.nn.Parameter(weight) self.weight = torch.nn.Parameter(weight, requires_grad=False)
if bias is not None: if bias is not None:
self.bias = torch.nn.Parameter(bias) self.bias = torch.nn.Parameter(bias, requires_grad=False)
else: else:
self.bias = None self.bias = None

View File

@@ -51,6 +51,7 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import ( from text_generation_server.models.flash_llama import (
FlashLlama, FlashLlama,
@@ -64,6 +65,9 @@ try:
from text_generation_server.models.flash_gemma import ( from text_generation_server.models.flash_gemma import (
FlashGemma, FlashGemma,
) )
from text_generation_server.models.pali_gemma import (
PaliGemma,
)
from text_generation_server.models.flash_santacoder import ( from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded, FlashSantacoderSharded,
) )
@@ -83,6 +87,7 @@ except ImportError as e:
HAS_FLASH_ATTN_V2_CUDA = False HAS_FLASH_ATTN_V2_CUDA = False
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded) __all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded) __all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded) __all__.append(FlashSantacoderSharded)
@@ -325,7 +330,27 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == "gpt2":
if FLASH_ATTENTION:
return FlashGPT2(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt_neox": elif model_type == "gpt_neox":
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashNeoXSharded( return FlashNeoXSharded(
@@ -654,6 +679,18 @@ def get_model(
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma":
if FLASH_ATTENTION:
return PaliGemma(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next": if model_type == "llava_next":
if FLASH_ATTENTION: if FLASH_ATTENTION:

View File

@@ -99,8 +99,13 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm): class GemmaFastRMSNorm(FastRMSNorm):
@classmethod @classmethod
def load(cls, prefix, weights, eps=1e-6): def load(cls, prefix, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1 weight = weights.get_tensor(f"{prefix}.weight") + 1
return cls(weight, eps) weights.dtype = dtype
new = cls(weight, eps)
new.dtype = dtype
return new
# perform the multiplication in full precision and downcast after # perform the multiplication in full precision and downcast after
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
@@ -111,7 +116,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
variance = hidden_states.pow(2).mean(-1, keepdim=True) variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * self.weight hidden_states = hidden_states * self.weight
return hidden_states.to(self.weight.dtype), residual return hidden_states.to(self.dtype), residual
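# Illustrative sketch only (not part of this diff), assuming the elided part of forward
# upcasts hidden_states to float32: load() above keeps (weight + 1) in float32 so the
# variance and the scaling run in full precision, and only the result is downcast.
import torch

def gemma_rmsnorm_reference(x: torch.Tensor, weight_fp32: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    out_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight_fp32).to(out_dtype)  # downcast only after the multiplication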
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
@@ -153,15 +158,11 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemmaAttention(torch.nn.Module): class FlashGemmaAttention(torch.nn.Module):
def __init__( def __init__(self, prefix: str, config, weights, causal: bool):
self,
prefix: str,
config,
weights,
):
super().__init__() super().__init__()
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_size = config.head_dim self.head_size = config.head_dim
self.causal = causal
self.rotary_emb = PositionRotaryEmbedding.static( self.rotary_emb = PositionRotaryEmbedding.static(
config=config, config=config,
@@ -238,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
causal=self.causal,
) )
# Decode # Decode
else: else:
@@ -295,11 +297,10 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module): class FlashGemmaLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, config, weights, causal: bool):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashGemmaAttention( self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
) )
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -351,30 +352,25 @@ class FlashGemmaLayer(nn.Module):
class FlashGemmaModel(torch.nn.Module): class FlashGemmaModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights, causal: bool):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
embed_norm = config.hidden_size**0.5
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashGemmaLayer( FlashGemmaLayer(
layer_id, prefix=f"{prefix}.layers.{layer_id}",
config, config=config,
weights, weights=weights,
causal=causal,
) )
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.norm = GemmaFastRMSNorm.load( self.norm = GemmaFastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@@ -385,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -394,7 +390,7 @@ class FlashGemmaModel(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
@@ -423,13 +419,30 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights, causal: bool):
super().__init__() super().__init__()
self.model = FlashGemmaModel(config, weights) embed_norm = config.hidden_size**0.5
if prefix is None:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemmaModel(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix=(
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights, weights=weights,
) )
@@ -445,8 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
max_s: int, max_s: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_embeds,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,

View File

@@ -0,0 +1,454 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
def load_qkv(config, prefix: str, weights, head_size, num_heads):
if config.quantize == "gptq":
return _load_qkv_gptq(
config,
prefix,
weights,
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
def _load_qkv_gptq(config, prefix: str, weights):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
# Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
total_size = shape[0]
assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
single_size = total_size // 3
assert single_size % world_size == 0
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(3):
tensor = slice_[start + i * single_size : stop + i * single_size]
tensors.append(tensor)
bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
"""Load QKV from a single, transposed matrix."""
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
total_size = shape[1]
assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
world_size = weights.process_group.size()
single_size = total_size // 3
assert single_size % world_size == 0
rank = weights.process_group.rank()
# Weights
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(3):
tensor = slice_[:, start + i * single_size : stop + i * single_size]
tensors.append(tensor)
weight = torch.cat(tensors, dim=1).T
weight = weight.to(dtype=weights.dtype)
weight = weight.to(device=weights.device)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
total_size = shape[0]
single_size = total_size // 3
block_size = single_size // world_size
assert single_size % world_size == 0
start = rank * block_size
stop = (rank + 1) * block_size
b = []
for i in range(3):
tensor = slice_[start + i * single_size : stop + i * single_size]
b.append(tensor)
bias = torch.cat(b, dim=0)
bias = bias.to(dtype=weights.dtype)
bias = bias.to(device=weights.device)
assert list(bias.shape) == [
3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
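# Hypothetical toy illustration (not part of this diff) of the slicing done by _load_qkv
# above: the GPT-2 Conv1D weight is stored as [hidden, 3 * hidden]; each tensor-parallel
# rank keeps its contiguous column block from the Q, K and V sections, then concatenates
# and transposes into the usual Linear layout. The sizes below are assumptions.
import torch

def _qkv_shard_example(hidden: int = 8, world_size: int = 2, rank: int = 0) -> torch.Tensor:
    fused = torch.arange(hidden * 3 * hidden, dtype=torch.float32).reshape(hidden, 3 * hidden)
    single_size = fused.shape[1] // 3              # width of one of Q / K / V
    block_size = single_size // world_size         # this rank's share of each section
    start, stop = rank * block_size, (rank + 1) * block_size
    shards = [fused[:, start + i * single_size : stop + i * single_size] for i in range(3)]
    weight = torch.cat(shards, dim=1).T            # [3 * block_size, hidden], Linear layout
    assert weight.shape == (3 * block_size, hidden)
    return weight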
def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
if bias and weights.process_group.rank() == 0:
# The bias is only loaded on the first rank (rank 0)
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_col(
[prefix], quantize=config.quantize, dim=1
)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
class FlashGPT2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = load_qkv(
config,
prefix=prefix,
weights=weights,
head_size=self.head_size,
num_heads=self.num_heads,
)
self.o_proj = load_row(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
)
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.activation_function
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
self.c_fc = load_col(
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
)
self.c_proj = load_row(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
)
intermediate_size = (
config.n_inner if config.n_inner is not None else 4 * config.hidden_size
)
self.intermediate_size = intermediate_size // weights.process_group.size()
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states)
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
)
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.ln_2",
weights=weights,
eps=config.layer_norm_epsilon,
)
def forward(
self,
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_output = self.self_attn(
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
return residual + mlp_output, residual
class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGPT2Layer(
prefix=(
f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
),
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = nn.LayerNorm.load(
prefix="ln_f" if not prefix else f"{prefix}.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=("wte" if not prefix else f"{prefix}.wte"),
weights=weights,
)
self.embed_positions = TensorParallelEmbedding(
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
weights=weights,
)
self.model = FlashGPT2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="wte" if not prefix else f"{prefix}.wte",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
token_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)
inputs_embeds = token_embeds + position_embeds
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = TensorParallelColumnLinear.load(
config,
prefix="multi_modal_projector.linear",
weights=weights,
bias=True,
)
self.vocab_size = config.vocab_size
self.config = config
text_config = config.text_config
text_config.speculator = config.speculator
text_config.quantize = config.quantize
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
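# Hypothetical toy example (not part of this diff) of the splice above: positions whose
# token id equals config.image_token_index are overwritten with the projected image
# features, one embedding row per image patch, while text positions keep their embeddings.
import torch

def _image_token_splice_example() -> torch.Tensor:
    image_token_index = 999                            # assumed toy id
    input_ids = torch.tensor([999, 999, 5, 7])         # two image tokens, two text tokens
    inputs_embeds = torch.zeros(4, 3)                  # [seq_len, hidden]
    image_features = torch.ones(1, 2, 3)               # [num_images, num_patches, hidden]
    mask = input_ids == image_token_index
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
    # rows 0 and 1 now hold image patch features; rows 2 and 3 are untouched text embeddings
    return inputs_embeds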

View File

@@ -0,0 +1,565 @@
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from text_generation_server.layers.tensor_parallel import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
persistent=False,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
patch_embeds = self.patch_embedding(
pixel_values
) # shape = [*, width, grid, grid]
embeddings = patch_embeds.flatten(2).transpose(1, 2)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
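# Hypothetical shape walk-through (not part of this diff) for the embedding above: a Conv2d
# whose stride equals its kernel size turns the image into a grid of patch embeddings, and
# flatten(2).transpose(1, 2) lays that grid out as a sequence. Sizes are assumed toy values.
import torch
from torch import nn

def _patch_embedding_shape_example() -> torch.Size:
    batch, channels, image_size, patch_size, embed_dim = 2, 3, 224, 16, 32
    conv = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    patch_embeds = conv(torch.randn(batch, channels, image_size, image_size))  # [2, 32, 14, 14]
    embeddings = patch_embeds.flatten(2).transpose(1, 2)                       # [2, 196, 32]
    assert embeddings.shape == (batch, (image_size // patch_size) ** 2, embed_dim)
    return embeddings.shape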
class SiglipTextEmbeddings(nn.Module):
def __init__(self, config: SiglipTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = (
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.head_size = self.head_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
)
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
)
self.out_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
# scale post matmul
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(attn_weights.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class SiglipMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class SiglipEncoderLayer(nn.Module):
def __init__(self, prefix, config: SiglipConfig, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
)
self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
if output_attentions:
return hidden_states, attn_weights
return hidden_states, None
class SiglipMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(
config.hidden_size, config.num_attention_heads, batch_first=True
)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(prefix, config, weights)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
import warnings
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
) -> torch.Tensor:
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
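# Hypothetical usage check (not part of this diff), relying on the torch import at the top
# of this file: with this 'tf' variant the cutoffs a=-2.0 and b=2.0 clip the standard-normal
# draw first, so after the final scale/shift every value lands in [mean - 2*std, mean + 2*std].
def _trunc_normal_tf_example() -> None:
    t = torch.empty(10_000)
    trunc_normal_tf_(t, mean=0.5, std=0.1)
    assert t.min() >= 0.5 - 2 * 0.1 - 1e-6 and t.max() <= 0.5 + 2 * 0.1 + 1e-6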
from torch.nn.init import _calculate_fan_in_and_fan_out
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
from transformers import PreTrainedModel
class SiglipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SiglipConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = (
self.config.vision_config.hidden_size
if isinstance(self.config, SiglipConfig)
else self.config.hidden_size
)
nn.init.normal_(module.position_embedding.weight, std=1 / math.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.xavier_uniform_(module.q_proj.weight)
nn.init.xavier_uniform_(module.k_proj.weight)
nn.init.xavier_uniform_(module.v_proj.weight)
nn.init.xavier_uniform_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.xavier_uniform_(module.fc1.weight)
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
elif isinstance(module, SiglipModel):
logit_scale_init = torch.log(torch.tensor(1.0))
module.logit_scale.data.fill_(logit_scale_init)
module.logit_bias.data.zero_()
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(self, prefix, config: SiglipConfig, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
SiglipEncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[torch.Tensor] = None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
hidden_states, _ = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = SiglipEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
):
r"""
Returns:
"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
# NOTE: up until this point, the values are exactly
# the same as in the transformers code. They evaluate
# slightly differently in our encoder layer.
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
)
last_hidden_state = encoder_outputs
post_last_hidden_state = self.post_layernorm(last_hidden_state)
return BaseModelOutputWithPooling(
last_hidden_state=post_last_hidden_state,
# pooler_output=pooled_output,
# hidden_states=encoder_outputs,
)
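# Illustrative usage sketch (names and shapes assumed, not part of this file):
#
#     vision_model = SiglipVisionTransformer(
#         prefix="vision_tower.vision_model", config=vision_config, weights=weights
#     )
#     out = vision_model(pixel_values)          # pixel_values: (batch, 3, H, W)
#     image_features = out.last_hidden_state    # (batch, num_patches, hidden_size)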

View File

@@ -11,6 +11,18 @@ def load_text_model(prefix, config, weights, name=None):
) )
return FlashMistralForCausalLM(prefix, config, weights, name=name) return FlashMistralForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "gemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
return FlashGemmaForCausalLM(prefix, config, weights)
else: else:
raise RuntimeError(f"Unsupported model type {config.model_type}") raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -24,5 +36,13 @@ def load_vision_model(prefix, config, weights):
return CLIPVisionTransformer( return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights prefix=f"{prefix}.vision_model", config=config, weights=weights
) )
if config.model_type == "siglip_vision_model":
from text_generation_server.models.custom_modeling.siglip import (
SiglipVisionTransformer,
)
return SiglipVisionTransformer(
prefix=f"vision_tower.vision_model", config=config, weights=weights
)
else: else:
raise RuntimeError(f"Unsupported model type {config.model_type}") raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@@ -135,6 +135,17 @@ class FlashCausalLMBatch(Batch):
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@classmethod
def from_tokenized(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
batch_tokenized_inputs,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
position_ids = [] position_ids = []
speculative_ids = [] speculative_ids = []
cu_seqlen_prefill = [0] cu_seqlen_prefill = [0]
@@ -209,6 +220,7 @@ class FlashCausalLMBatch(Batch):
# Paged attention # Paged attention
# Remove one as the first token does not have a past # Remove one as the first token does not have a past
speculative_length = get_speculate() speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length
total_tokens = input_length + max_new_tokens - 1 + speculative_length total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks blocks += needed_blocks
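# Worked example (illustrative, assuming BLOCK_SIZE = 16): with input_length=20,
# max_new_tokens=12 and speculative_length=0,
#   total_tokens  = 20 + 12 - 1 + 0 = 31
#   needed_blocks = ceil(31 / 16)   = 2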
@@ -760,7 +772,7 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
empty_cache() empty_cache()
try: try:
cache_manager = set_cache_manager( cache_manager = set_cache_manager(
batch.blocks, batch.blocks,

View File

@@ -3,12 +3,11 @@ import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from typing import Optional from typing import Optional
from transformers.models.gemma import GemmaTokenizerFast from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM, FlashGemmaForCausalLM,
GemmaConfig,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
@@ -38,17 +37,15 @@ class FlashGemma(FlashCausalLM):
else: else:
raise NotImplementedError("FlashGemma is only available on GPU") raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = GemmaTokenizerFast.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
padding_side="left", padding_side="left",
truncation_side="left", truncation_side="left",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
) )
config = GemmaConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
@@ -61,7 +58,9 @@ class FlashGemma(FlashCausalLM):
if config.quantize in ["gptq", "awq"]: if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
model = FlashGemmaForCausalLM(config, weights) # TODO hardcoded
prefix = "language_model"
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__( super(FlashGemma, self).__init__(

View File

@@ -0,0 +1,78 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
class FlashGPT2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
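# Illustrative sketch (hypothetical values, not part of this commit): construction
# mirrors the other flash models, e.g.
#
#     model = FlashGPT2("openai-community/gpt2", quantize=None, dtype=torch.float16)
#
# num_layers, num_kv_heads and head_size are taken from the loaded module so the
# shared FlashCausalLM warmup and KV-cache sizing apply unchanged.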

View File

@@ -0,0 +1,123 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch,
image_text_replacement,
load_data_uri,
split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(VlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += "<bos>" + chunk["content"] + "\n"
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore; processing should be done
# on the Rust layer.
# This avoids making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
# TODO: should do_convert_RGB be on by default?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
class PaliGemma(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
super().__init__(
config_cls=AutoConfig,
model_cls=PaliGemmaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self):
return PaliGemmaBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@@ -15,6 +15,7 @@ from text_generation_server.models.flash_mistral import (
BaseFlashMistral, BaseFlashMistral,
FlashMistralBatch, FlashMistralBatch,
) )
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import ( from text_generation_server.models.cache_manager import (
get_cache_manager, get_cache_manager,
) )
@@ -80,6 +81,9 @@ def image_text_replacement(image_input, config, image_id) -> str:
logger.info(f"Found {num_features} in image of resolution {height}x{width}") logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features return "<image>" * num_features
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
else: else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal") raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@@ -193,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer( batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"] )["input_ids"]
if image_inputs: if image_inputs:
image_input = image_inputs[0] image_input = image_inputs[0]

View File

@@ -14,7 +14,10 @@ from typing import List, Optional
from text_generation_server.cache import Cache from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model from text_generation_server.models import Model, get_model
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@@ -98,6 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if self.model.batch_type in { if self.model.batch_type in {
IdeficsCausalLMBatch, IdeficsCausalLMBatch,
VlmCausalLMBatch, VlmCausalLMBatch,
PaliGemmaBatch,
}: # Hack, I would rather use kwargs in the `from_pb` call }: # Hack, I would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor( batch = self.model.batch_type.from_pb_processor(
request.batch, request.batch,
@@ -122,6 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if self.model.batch_type in { if self.model.batch_type in {
IdeficsCausalLMBatch, IdeficsCausalLMBatch,
VlmCausalLMBatch, VlmCausalLMBatch,
PaliGemmaBatch,
}: # Hack, I would rather use kwargs in the `from_pb` call }: # Hack, I would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor( batch = self.model.batch_type.from_pb_processor(
request.batch, request.batch,

View File

@@ -131,6 +131,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
max_s, max_s,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=True,
): ):
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
@@ -149,7 +150,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
0.0, 0.0,
softmax_scale, softmax_scale,
False, False,
True, causal,
window_size_left, window_size_left,
0, 0,
False, False,