Add llama4 (#3145)

* initial changes

* Add support for other vlm

* cleanup comment

* Improve attn_implementation

* Add comments for support of models

* add model

* add model

* fixes and improvements

* update docker

* Add cache position

* Add tests

* remove redundant changes

* remove tr version

* Upgrade doc + fix linting.

* Fixing the CI.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Authored by Mohit Sharma on 2025-04-06 13:50:22 +05:30, committed by GitHub
parent 3d059f91ab
commit d9bb9bebc9
19 changed files with 1893 additions and 61 deletions


@@ -65,7 +65,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
-RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} pip setuptools packaging
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
@@ -193,6 +193,9 @@ RUN cd server && \
pwd && \
text-generation-server --help
# This shouldn't be necessary.
# RUN uv pip install torchvision --no-deps
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages


@@ -9,6 +9,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Llama4](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)

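A served Llama 4 checkpoint can then be queried through TGI's OpenAI-compatible Messages API. A minimal Python sketch, assuming an instance is already running on localhost:8080 and serving the Scout checkpoint used by the integration tests in this diff; host, port, prompt, and image URL are illustrative:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

response = client.chat.completions.create(
    model="tgi",  # TGI routes every request to the model it was launched with
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
                    },
                },
                {"type": "text", "text": "What is shown in this image?"},
            ],
        }
    ],
    max_tokens=100,
)
print(response.choices[0].message.content)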

@@ -0,0 +1,613 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 100,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2721,
"logprob": -0.21582031,
"special": false,
"text": " people"
},
{
"id": 21807,
"logprob": -0.26953125,
"special": false,
"text": " died"
},
{
"id": 310,
"logprob": -0.95703125,
"special": false,
"text": " in"
},
{
"id": 290,
"logprob": -1.3359375,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -1.3828125,
"special": false,
"text": " "
},
{
"id": 7284,
"logprob": -0.011291504,
"special": false,
"text": "191"
},
{
"id": 36,
"logprob": -0.011413574,
"special": false,
"text": "8"
},
{
"id": 18938,
"logprob": -0.23242188,
"special": false,
"text": " flu"
},
{
"id": 27650,
"logprob": -0.0010070801,
"special": false,
"text": " pandemic"
},
{
"id": 26,
"logprob": -0.69140625,
"special": false,
"text": "."
},
{
"id": 114059,
"logprob": -1.4375,
"special": false,
"text": " Estimating"
},
{
"id": 290,
"logprob": -0.24316406,
"special": false,
"text": " the"
},
{
"id": 10593,
"logprob": -0.37304688,
"special": false,
"text": " death"
},
{
"id": 49973,
"logprob": -0.025390625,
"special": false,
"text": " toll"
},
{
"id": 323,
"logprob": -0.27539062,
"special": false,
"text": " of"
},
{
"id": 290,
"logprob": -0.057617188,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -0.040527344,
"special": false,
"text": " "
},
{
"id": 7284,
"logprob": -0.00050735474,
"special": false,
"text": "191"
},
{
"id": 36,
"logprob": -9.298325e-06,
"special": false,
"text": "8"
},
{
"id": 18938,
"logprob": -0.09863281,
"special": false,
"text": " flu"
},
{
"id": 27650,
"logprob": -0.0011749268,
"special": false,
"text": " pandemic"
},
{
"id": 373,
"logprob": -0.32421875,
"special": false,
"text": " is"
},
{
"id": 8210,
"logprob": -0.58203125,
"special": false,
"text": " difficult"
},
{
"id": 2895,
"logprob": -0.40429688,
"special": false,
"text": " because"
},
{
"id": 323,
"logprob": -1.2734375,
"special": false,
"text": " of"
},
{
"id": 49119,
"logprob": -0.51171875,
"special": false,
"text": " incomplete"
},
{
"id": 13308,
"logprob": -0.38085938,
"special": false,
"text": " records"
},
{
"id": 341,
"logprob": -0.55859375,
"special": false,
"text": " and"
},
{
"id": 2895,
"logprob": -0.765625,
"special": false,
"text": " because"
},
{
"id": 323,
"logprob": -1.0,
"special": false,
"text": " of"
},
{
"id": 290,
"logprob": -0.828125,
"special": false,
"text": " the"
},
{
"id": 2304,
"logprob": -1.015625,
"special": false,
"text": " fact"
},
{
"id": 511,
"logprob": -0.004638672,
"special": false,
"text": " that"
},
{
"id": 2233,
"logprob": -0.953125,
"special": false,
"text": " many"
},
{
"id": 323,
"logprob": -0.87890625,
"special": false,
"text": " of"
},
{
"id": 290,
"logprob": -0.60546875,
"special": false,
"text": " the"
},
{
"id": 6759,
"logprob": -1.6484375,
"special": false,
"text": " extra"
},
{
"id": 40657,
"logprob": -0.00022125244,
"special": false,
"text": " deaths"
},
{
"id": 1610,
"logprob": -0.67578125,
"special": false,
"text": " were"
},
{
"id": 702,
"logprob": -0.30664062,
"special": false,
"text": " not"
},
{
"id": 48692,
"logprob": -0.1953125,
"special": false,
"text": " attributed"
},
{
"id": 328,
"logprob": -0.0079956055,
"special": false,
"text": " to"
},
{
"id": 290,
"logprob": -0.515625,
"special": false,
"text": " the"
},
{
"id": 18938,
"logprob": -0.0040893555,
"special": false,
"text": " flu"
},
{
"id": 26,
"logprob": -0.083496094,
"special": false,
"text": "."
},
{
"id": 13618,
"logprob": -0.515625,
"special": false,
"text": " Many"
},
{
"id": 22215,
"logprob": -1.5703125,
"special": false,
"text": " experts"
},
{
"id": 11081,
"logprob": -0.96875,
"special": false,
"text": " believe"
},
{
"id": 511,
"logprob": -0.1171875,
"special": false,
"text": " that"
},
{
"id": 290,
"logprob": -0.25195312,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -0.828125,
"special": false,
"text": " "
},
{
"id": 7284,
"logprob": -0.00010967255,
"special": false,
"text": "191"
},
{
"id": 36,
"logprob": -8.535385e-05,
"special": false,
"text": "8"
},
{
"id": 18938,
"logprob": -0.056152344,
"special": false,
"text": " flu"
},
{
"id": 27650,
"logprob": -0.0007095337,
"special": false,
"text": " pandemic"
},
{
"id": 26132,
"logprob": -0.18847656,
"special": false,
"text": " killed"
},
{
"id": 1867,
"logprob": -0.71484375,
"special": false,
"text": " between"
},
{
"id": 220,
"logprob": -0.0062561035,
"special": false,
"text": " "
},
{
"id": 1175,
"logprob": -0.009277344,
"special": false,
"text": "50"
},
{
"id": 341,
"logprob": -0.15332031,
"special": false,
"text": " and"
},
{
"id": 220,
"logprob": -8.34465e-07,
"special": false,
"text": " "
},
{
"id": 1135,
"logprob": -0.00065612793,
"special": false,
"text": "100"
},
{
"id": 5534,
"logprob": -1.4066696e-05,
"special": false,
"text": " million"
},
{
"id": 2721,
"logprob": -0.0008392334,
"special": false,
"text": " people"
},
{
"id": 26,
"logprob": -0.54296875,
"special": false,
"text": "."
},
{
"id": 372,
"logprob": -1.8046875,
"special": false,
"text": " I"
},
{
"id": 140680,
"logprob": -0.578125,
"special": false,
"text": "assistant"
},
{
"id": 200006,
"logprob": 0.0,
"special": true,
"text": "<|header_end|>"
},
{
"id": 368,
"logprob": 0.0,
"special": false,
"text": "\n\n"
},
{
"id": 954,
"logprob": -0.032226562,
"special": false,
"text": "The"
},
{
"id": 220,
"logprob": -4.4345856e-05,
"special": false,
"text": " "
},
{
"id": 7284,
"logprob": 0.0,
"special": false,
"text": "191"
},
{
"id": 36,
"logprob": 0.0,
"special": false,
"text": "8"
},
{
"id": 18938,
"logprob": -0.015625,
"special": false,
"text": " flu"
},
{
"id": 27650,
"logprob": 0.0,
"special": false,
"text": " pandemic"
},
{
"id": 24,
"logprob": -0.0072021484,
"special": false,
"text": ","
},
{
"id": 1437,
"logprob": -0.0001707077,
"special": false,
"text": " also"
},
{
"id": 5711,
"logprob": 0.0,
"special": false,
"text": " known"
},
{
"id": 486,
"logprob": 0.0,
"special": false,
"text": " as"
},
{
"id": 290,
"logprob": -5.9604645e-07,
"special": false,
"text": " the"
},
{
"id": 25836,
"logprob": -1.4305115e-06,
"special": false,
"text": " Spanish"
},
{
"id": 18938,
"logprob": -0.0015029907,
"special": false,
"text": " flu"
},
{
"id": 24,
"logprob": -0.0052490234,
"special": false,
"text": ","
},
{
"id": 373,
"logprob": -0.3125,
"special": false,
"text": " is"
},
{
"id": 26078,
"logprob": -0.21289062,
"special": false,
"text": " indeed"
},
{
"id": 1085,
"logprob": -0.080078125,
"special": false,
"text": " one"
},
{
"id": 323,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 290,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 2167,
"logprob": -0.20117188,
"special": false,
"text": " most"
},
{
"id": 92679,
"logprob": -0.12695312,
"special": false,
"text": " devastating"
},
{
"id": 854,
"logprob": -0.25976562,
"special": false,
"text": " public"
},
{
"id": 4500,
"logprob": 0.0,
"special": false,
"text": " health"
},
{
"id": 93079,
"logprob": -0.50390625,
"special": false,
"text": " crises"
},
{
"id": 310,
"logprob": 0.0,
"special": false,
"text": " in"
},
{
"id": 6023,
"logprob": -0.0015182495,
"special": false,
"text": " human"
},
{
"id": 7068,
"logprob": 0.0,
"special": false,
"text": " history"
},
{
"id": 26,
"logprob": -0.0012664795,
"special": false,
"text": "."
},
{
"id": 114059,
"logprob": -0.004119873,
"special": false,
"text": " Estimating"
},
{
"id": 290,
"logprob": -0.00033569336,
"special": false,
"text": " the"
},
{
"id": 6318,
"logprob": -0.20117188,
"special": false,
"text": " exact"
}
],
"top_tokens": null
},
"generated_text": " people died in the 1918 flu pandemic. Estimating the death toll of the 1918 flu pandemic is difficult because of incomplete records and because of the fact that many of the extra deaths were not attributed to the flu. Many experts believe that the 1918 flu pandemic killed between 50 and 100 million people. Iassistant\n\nThe 1918 flu pandemic, also known as the Spanish flu, is indeed one of the most devastating public health crises in human history. Estimating the exact"
}


@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image is a blank white space with no visible objects or features. It appears to be an empty or placeholder image, devoid of any content or visual elements.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1743861910,
"id": "",
"model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 34,
"prompt_tokens": 166,
"total_tokens": 200
}
}


@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image is a blank white space with no visible objects or features.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1743861909,
"id": "",
"model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 15,
"prompt_tokens": 166,
"total_tokens": 181
}
}


@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "The image is a black background with no discernible objects or features. The image appears to be a blank or empty space, devoid of any visual elements.\n\n**Key Features:**\n\n* **Color:** The dominant color of the image is black.\n* **Objects:** There are no visible objects or shapes in the image.\n* **Background:** The background of the image is a solid black color.\n\n**Conclusion:**\nIn summary, the image is a simple and empty visual representation with a black background and no",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1743861909,
"id": "",
"model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 166,
"total_tokens": 266
}
}


@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image shows a brown cow standing on the beach with a white face and black and white marking on its ears. The cow has a white patch around its nose and mouth. The ocean and blue sky are in the background.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1743863057,
"id": "",
"model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 46,
"prompt_tokens": 164,
"total_tokens": 210
}
}


@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image does not depict a dog; it shows a cow standing on a beach. Therefore, there is no breed of a dog to identify.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1743863056,
"id": "",
"model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 30,
"prompt_tokens": 168,
"total_tokens": 198
}
}


@@ -0,0 +1,155 @@
import base64
from io import BytesIO
from PIL import Image
import pytest
@pytest.fixture(scope="module")
def flash_llama4_handle(launcher):
with launcher("ll-re/Llama-4-Scout-17B-16E-Instruct", num_shard=8) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama4(flash_llama4_handle):
await flash_llama4_handle.health(300)
return flash_llama4_handle.client
async def test_flash_llama4(flash_llama4, response_snapshot):
response = await flash_llama4.generate(
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
seed=42,
max_new_tokens=100,
)
assert (
response.generated_text
== " people died in the 1918 flu pandemic. Estimating the death toll of the 1918 flu pandemic is difficult because of incomplete records and because of the fact that many of the extra deaths were not attributed to the flu. Many experts believe that the 1918 flu pandemic killed between 50 and 100 million people. Iassistant\n\nThe 1918 flu pandemic, also known as the Spanish flu, is indeed one of the most devastating public health crises in human history. Estimating the exact"
)
assert response.details.generated_tokens == 100
assert response == response_snapshot
async def test_flash_llama4_image_cow_dog(flash_llama4, response_snapshot):
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = await flash_llama4.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{
"type": "text",
"text": "What is the breed of the dog in the image?",
},
],
},
],
max_tokens=100,
)
assert (
response.choices[0].message.content
== "The image does not depict a dog; it shows a cow standing on a beach. Therefore, there is no breed of a dog to identify."
)
assert response.usage["completion_tokens"] == 30
assert response == response_snapshot
async def test_flash_llama4_image_cow(flash_llama4, response_snapshot):
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = await flash_llama4.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "What is shown in this image?"},
],
},
],
max_tokens=100,
)
assert (
response.choices[0].message.content
== "The image shows a brown cow standing on the beach with a white face and black and white marking on its ears. The cow has a white patch around its nose and mouth. The ocean and blue sky are in the background."
)
assert response.usage["completion_tokens"] == 46
assert response == response_snapshot
# Helper function to convert a Pillow image to a base64 data URL
def image_to_data_url(img: Image.Image, fmt: str) -> str:
buffer = BytesIO()
img.save(buffer, format=fmt)
img_data = buffer.getvalue()
b64_str = base64.b64encode(img_data).decode("utf-8")
mime_type = "image/png" if fmt.upper() == "PNG" else "image/jpeg"
return f"data:{mime_type};base64,{b64_str}"
async def test_flash_llama4_image_base64_rgba(flash_llama4, response_snapshot):
# Create an empty 100x100 PNG image with alpha (transparent background)
img = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
data_url = image_to_data_url(img, "PNG")
response = await flash_llama4.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{
"type": "text",
"text": "What do you see in this transparent image?",
},
],
},
],
max_tokens=100,
)
assert response == response_snapshot
async def test_flash_llama4_image_base64_rgb_png(flash_llama4, response_snapshot):
# Create an empty 100x100 PNG image without alpha (white background)
img = Image.new("RGB", (100, 100), (255, 255, 255))
data_url = image_to_data_url(img, "PNG")
response = await flash_llama4.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": "What do you see in this plain image?"},
],
},
],
max_tokens=100,
)
assert response == response_snapshot
async def test_flash_llama4_image_base64_rgb_jpg(flash_llama4, response_snapshot):
# Create an empty 100x100 JPEG image (white background)
img = Image.new("RGB", (100, 100), (255, 255, 255))
data_url = image_to_data_url(img, "JPEG")
response = await flash_llama4.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": "What do you see in this JPEG image?"},
],
},
],
max_tokens=100,
)
assert response == response_snapshot


@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
@@ -103,6 +104,141 @@ impl LlavaNext {
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Llama4VisionConfig {
image_size: usize,
patch_size: usize,
pixel_shuffle_ratio: f64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Llama4 {
text_config: TextConfig,
vision_config: Llama4VisionConfig,
}
fn gcd(a: usize, b: usize) -> usize {
if b == 0 {
a
} else {
gcd(b, a % b)
}
}
fn get_factors(dividend: usize) -> HashSet<usize> {
let mut factors_set = HashSet::new();
for i in 1..=((dividend as f64).sqrt() as usize) {
if dividend % i == 0 {
factors_set.insert(i);
factors_set.insert(dividend / i);
}
}
factors_set
}
fn find_supported_resolutions(max_num_chunks: usize, height: usize) -> Vec<(usize, usize)> {
let patch_size = height;
let mut asp_dict: HashMap<(usize, usize), Vec<(usize, usize)>> = HashMap::new();
for chunk_size in (1..=max_num_chunks).rev() {
let mut _factors: Vec<_> = get_factors(chunk_size).into_iter().collect();
_factors.sort();
let _asp_ratios: Vec<(usize, usize)> =
_factors.iter().map(|&f| (f, chunk_size / f)).collect();
for (h, w) in _asp_ratios {
let divisor = gcd(h, w);
let key = (h / divisor, w / divisor); // reduced aspect ratio as key
asp_dict.entry(key).or_default().push((h, w));
}
}
let mut possible_resolutions = vec![];
for (_key, value) in asp_dict {
for (h, w) in value {
possible_resolutions.push((h * patch_size, w * patch_size));
}
}
possible_resolutions
}
fn get_best_fit(
original_height: usize,
original_width: usize,
possible_resolutions: &[(usize, usize)],
resize_to_max_canvas: bool,
) -> (usize, usize) {
let orig_h = original_height as f32;
let orig_w = original_width as f32;
let mut scales = Vec::with_capacity(possible_resolutions.len());
for &(h, w) in possible_resolutions.iter() {
let scale_h = h as f32 / orig_h;
let scale_w = w as f32 / orig_w;
let scale = scale_h.min(scale_w);
scales.push(scale);
}
let upscaling_options: Vec<f32> = scales.iter().copied().filter(|&s| s >= 1.0).collect();
let selected_scale = if !upscaling_options.is_empty() {
if resize_to_max_canvas {
upscaling_options.into_iter().fold(f32::MIN, f32::max)
} else {
upscaling_options.into_iter().fold(f32::MAX, f32::min)
}
} else {
let downscaling_options: Vec<f32> = scales.iter().copied().filter(|&s| s < 1.0).collect();
downscaling_options.into_iter().fold(f32::MIN, f32::max)
};
let chosen_canvas: Vec<(usize, usize)> = possible_resolutions
.iter()
.zip(scales.iter())
.filter(|&(_, &s)| (s - selected_scale).abs() < f32::EPSILON)
.map(|(&(h, w), _)| (h, w))
.collect();
if chosen_canvas.len() > 1 {
chosen_canvas
.into_iter()
.min_by_key(|(h, w)| h * w)
.unwrap()
} else {
chosen_canvas[0]
}
}
impl Llama4 {
pub fn image_size(&self) -> usize {
self.vision_config.image_size
}
pub fn patch_size(&self) -> usize {
self.vision_config.patch_size
}
pub fn pixel_shuffle_ratio(&self) -> f64 {
self.vision_config.pixel_shuffle_ratio
}
pub fn get_aspect_ratios(&self, height: usize, width: usize) -> (usize, usize) {
let patch_size = self.vision_config.image_size;
// How to avoid hardcoding this?
let max_chunks = 15;
let supported = find_supported_resolutions(max_chunks, patch_size);
let (target_h, target_w) = get_best_fit(height, width, &supported, false);
(target_h / patch_size, target_w / patch_size)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
@@ -258,6 +394,7 @@ pub enum Config {
Phi3,
Phimoe,
Llama,
Llama4(Llama4),
Baichuan,
Paligemma(Paligemma),
Gemma,

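The resolution-selection helpers above (get_factors, find_supported_resolutions, get_best_fit, get_aspect_ratios) decide how an image is tiled: every grid of at most 15 tiles of image_size pixels is a candidate canvas, and the smallest canvas that still covers the image is preferred (or the least-downscaled one when none covers it). A minimal Python sketch of the same selection, for illustration only; the defaults image_size=336 and max_chunks=15 are taken from this diff rather than from any public API:

import math

def get_factors(n: int) -> set:
    # All divisors of n, e.g. 12 -> {1, 2, 3, 4, 6, 12}
    return {
        d
        for i in range(1, math.isqrt(n) + 1)
        if n % i == 0
        for d in (i, n // i)
    }

def find_supported_resolutions(max_num_chunks: int, patch_size: int):
    # Every tile grid (h, w) with h * w <= max_num_chunks, expressed in pixels.
    resolutions = []
    for chunks in range(max_num_chunks, 0, -1):
        for f in sorted(get_factors(chunks)):
            h, w = f, chunks // f
            resolutions.append((h * patch_size, w * patch_size))
    return resolutions

def get_best_fit(orig_h, orig_w, possible, resize_to_max_canvas=False):
    scales = [min(h / orig_h, w / orig_w) for h, w in possible]
    upscaling = [s for s in scales if s >= 1.0]
    if upscaling:
        selected = max(upscaling) if resize_to_max_canvas else min(upscaling)
    else:
        selected = max(s for s in scales if s < 1.0)
    # Ties are broken by the smallest canvas area.
    candidates = [r for r, s in zip(possible, scales) if abs(s - selected) < 1e-9]
    return min(candidates, key=lambda r: r[0] * r[1])

def get_aspect_ratios(height, width, image_size=336, max_chunks=15):
    supported = find_supported_resolutions(max_chunks, image_size)
    target_h, target_w = get_best_fit(height, width, supported)
    return target_h // image_size, target_w // image_size

# An 800x1200 image with 336-pixel tiles ends up on a 3x4 tile grid.
print(get_aspect_ratios(800, 1200))  # (3, 4)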

@@ -179,6 +179,7 @@ pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor),
Idefics3Processor(Idefics2Preprocessor),
Gemma3Processor(Gemma3Processor),
Llama4Processor(Llama4Processor),
}
impl HubPreprocessorConfig {
@@ -200,6 +201,12 @@ pub struct Gemma3Processor {
do_image_splitting: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Llama4Processor {
#[serde(default)]
do_image_splitting: bool,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,


@@ -687,6 +687,47 @@ fn image_tokens(
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
Llama4(config) => {
const IMAGE_START: &str = "<|image_start|>";
const IMAGE: &str = "<|image|>";
const IMAGE_END: &str = "<|image_end|>";
const PATCH: &str = "<|patch|>";
const TILE_X_SEP: &str = "<|tile_x_separator|>";
const TILE_Y_SEP: &str = "<|tile_y_separator|>";
let image_height = config.image_size();
let patch_size = config.patch_size();
let pixel_shuffle_ratio = config.pixel_shuffle_ratio();
let downsample_ratio =
(1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize;
let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width);
let image_width = image_height; // Assuming pixel shape: [H][W][C]
let num_patches_per_chunk =
(image_height / patch_size) * (image_width / patch_size) / downsample_ratio;
let mut img_string = String::new();
img_string.push_str(IMAGE_START);
if ratio_h * ratio_w > 1 {
for _yy in 0..ratio_h {
for xx in 0..ratio_w {
img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
if xx < ratio_w - 1 {
img_string.push_str(TILE_X_SEP);
}
}
img_string.push_str(TILE_Y_SEP);
}
}
img_string.push_str(IMAGE);
img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
img_string.push_str(IMAGE_END);
img_string
}
Qwen2Vl(config) => format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
@@ -730,8 +771,8 @@ fn prepare_input<T: TokenizerTrait>(
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
-config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Paligemma(_)
-| LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Llama4(_)
| Paligemma(_) | LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());

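Concretely, the Llama4 arm of image_tokens above emits one run of <|patch|> tokens per tile, separated by <|tile_x_separator|> within a row and <|tile_y_separator|> at the end of each row, and closes with a global <|image|> chunk. A small Python sketch of that layout (the helper name is illustrative; the server-side counterpart, prompt_split_image_llama4, appears later in this diff):

def llama4_image_string(ratio_h: int, ratio_w: int, num_patches_per_chunk: int) -> str:
    # Mirrors the token layout the router builds for a Llama 4 image.
    s = "<|image_start|>"
    if ratio_h * ratio_w > 1:
        for _yy in range(ratio_h):
            for xx in range(ratio_w):
                s += "<|patch|>" * num_patches_per_chunk
                if xx < ratio_w - 1:
                    s += "<|tile_x_separator|>"
            s += "<|tile_y_separator|>"
    # Final global image chunk.
    s += "<|image|>" + "<|patch|>" * num_patches_per_chunk + "<|image_end|>"
    return s

# A 2x2 tiling with an illustratively small 4 patches per tile:
print(llama4_image_string(2, 2, 4))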

@@ -32,7 +32,8 @@ dependencies = [
"tokenizers>=0.20.3",
"typer>=0.15.1",
"transformers>=4.49.0",
-"huggingface-hub>=0.29.0",
"huggingface-hub>=0.30.1",
"hf-xet>=1.0.0",
]
[build-system]


@@ -206,7 +206,13 @@ try:
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,
)
-except ImportError:
from text_generation_server.models.transformers_flash_vlm import (
TransformersFlashVlmCausalLM,
TransformersGemma3VlmCausalLM,
TransformersLlama4VlmCausalLM,
)
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
FLASH_TRANSFORMERS_BACKEND = False
@@ -244,6 +250,11 @@ class ModelType(enum.Enum):
"name": "Llama",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
LLAMA4 = {
"type": "llama4",
"name": "Llama4",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
@@ -648,7 +659,6 @@ def get_model(
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:
head_size = max(
@@ -1017,7 +1027,23 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == LLAMA4:
if FLASH_TRANSFORMERS_BACKEND:
from transformers import Llama4ForConditionalGeneration as Llama4Model
return TransformersLlama4VlmCausalLM.fallback(
model_id,
Llama4Model,
revision,
quantize=quantize,
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
processor_kwargs={
"use_fast": True,
"size": {"height": 336, "width": 336},
},
)
elif model_type == BAICHUAN:
if FLASH_ATTENTION:
return FlashCausalLM(
@@ -1155,7 +1181,6 @@ def get_model(
)
elif model_type == GEMMA3:
if FLASH_ATTENTION:
-# TODO: Use VlmCausalLM when image support is added.
return VlmCausalLM(
model_id=model_id,
model_class=Gemma3ForConditionalGeneration,
@@ -1173,12 +1198,15 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
-return TransformersFlashCausalLM.fallback(
from transformers import Gemma3ForConditionalGeneration as Gemma3Model
return TransformersGemma3VlmCausalLM.fallback(
model_id,
Gemma3Model,
revision,
quantize=quantize,
speculator=speculator,
-dtype=dtype,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
)
elif sharded:
@@ -1483,6 +1511,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == QWEN2_VL:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
@@ -1495,7 +1524,23 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
# TODO: Uncomment when transformers is refactored
# elif FLASH_TRANSFORMERS_BACKEND:
# from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
# return TransformersQwen2VlmCausalLM.fallback(
# model_id,
# Qwen2VLModel,
# revision,
# quantize=quantize,
# speculator=speculator,
# dtype=torch.bfloat16,
# trust_remote_code=trust_remote_code,
# )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
if model_type == QWEN2_5_VL:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
@@ -1510,6 +1555,21 @@ def get_model(
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
# TODO: Uncomment when transformers is refactored
# elif FLASH_TRANSFORMERS_BACKEND:
# return TransformersQwen2VlmCausalLM.fallback(
# model_id,
# Qwen2VLModel,
# revision,
# quantize=quantize,
# speculator=speculator,
# dtype=torch.bfloat16,
# trust_remote_code=trust_remote_code,
# config_class=Qwen2_5_VLConfig,
# processor_class=Qwen2_5_VLProcessor,
# )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
@@ -1524,6 +1584,20 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
# TODO: Uncomment when transformers is refactored and cross attn is added
# elif FLASH_TRANSFORMERS_BACKEND:
# from transformers import MllamaForConditionalGeneration as MllamaModel
# return TransformersFlashVlmCausalLM.fallback(
# model_id,
# MllamaModel,
# revision,
# quantize=quantize,
# speculator=speculator,
# dtype=torch.bfloat16,
# trust_remote_code=trust_remote_code,
# batch_class=MllamaCausalLMBatch,
# )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2:
@@ -1542,6 +1616,19 @@ def get_model(
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Idefics2ForConditionalGeneration as Idefics2Model
return TransformersFlashVlmCausalLM.fallback(
model_id,
Idefics2Model,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS3:
@@ -1560,6 +1647,19 @@ def get_model(
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Idefics3ForConditionalGeneration as Idefics3Model
return TransformersFlashVlmCausalLM.fallback(
model_id,
Idefics3Model,
revision,
quantize=quantize,
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
processor_kwargs={"size": {"longest_edge": 1456}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == PALIGEMMA:
@@ -1578,9 +1678,21 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
-else:
elif FLASH_TRANSFORMERS_BACKEND:
-raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
return TransformersFlashVlmCausalLM.fallback(
model_id,
PaliGemmaModel,
revision,
quantize=quantize,
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
batch_class=PaliGemmaBatch,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))
if model_type == LLAVA_NEXT:
if FLASH_ATTENTION:
return VlmCausalLM(
@@ -1593,6 +1705,18 @@ def get_model(
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import LlavaNextForConditionalGeneration as LlavaNextModel
return TransformersFlashVlmCausalLM.fallback(
model_id,
LlavaNextModel,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))


@@ -1344,9 +1344,6 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
-def max_past(self) -> int:
-return getattr(self.model, "max_past", None)
def init_kv_cache(
self,
num_blocks: int,
@@ -1792,12 +1789,6 @@ class FlashCausalLM(Model):
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
-if cu_seqlen_prefill is None and self.max_past() is not None:
-# In decode, not prefill, we're actually overwriting the KV-cache
-# in a circular buffer mode.
-# This makes sure the max_s for the decode pass is correct.
-max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:


@@ -36,6 +36,7 @@ def tgi_flash_attention_forward(
softcap: Optional[float] = None,
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
):
kv_cache = kv_cache[module.layer_idx]
query_states = query_states.transpose(1, 2).squeeze(dim=0)
key_states = key_states.transpose(1, 2).squeeze(dim=0)
@@ -72,6 +73,7 @@ def tgi_flash_attention_forward(
max_s,
kv_scales=kv_scales,
softcap=softcap,
window_size_left=sliding_window,
)
attn_output = attn_output.view(-1, num_heads * head_dim)
@@ -157,7 +159,14 @@ class TransformersFlashCausalLM(FlashCausalLM):
self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads
self.num_kv_heads = model.config.num_key_value_heads
-self.head_size = model.config.hidden_size // model.config.num_attention_heads
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(model.config, "head_dim"):
self.head_size = model.config.head_dim
else:
self.head_size = (
model.config.hidden_size // model.config.num_attention_heads
)
# Skip it for models in the exception list
if model.config.model_type not in REPLICATED_ATTENTION_MODELS:


@@ -0,0 +1,566 @@
import math
from typing import List, Optional
import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoProcessor
import transformers.modeling_utils
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
from text_generation_server.utils import initialize_torch_distributed
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
import torch.nn.functional as F
tracer = trace.get_tracer(__name__)
# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
# to internal constraints it was not (yet?) possible to circumvent
REPLICATED_ATTENTION_MODELS = [
"olmo2",
"phi3",
]
# # Qwen2VL
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
# "tgi"
# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
# "eager"
# ]
def tgi_flash_attention_forward(
module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers
kv_cache: List[KVCache],
kv_head_mapping: torch.Tensor,
slots: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
seqlen: Seqlen,
block_tables: torch.Tensor,
max_s: int,
kv_scales: KVScales,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
use_sdpa: Optional[bool] = False,
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
):
kv_cache = kv_cache[module.layer_idx]
query_states = query_states.transpose(1, 2).squeeze(dim=0)
key_states = key_states.transpose(1, 2).squeeze(dim=0)
value_states = value_states.transpose(1, 2).squeeze(dim=0)
# Take care of updating the cache in-place
kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
_, num_heads, head_dim = query_states.shape
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
sliding_window = -1 if sliding_window is None else sliding_window
if cu_seqlen_prefill is not None:
if not use_sdpa:
attn_output = attention(
query=query_states,
key=key_states,
value=value_states,
kv_cache=kv_cache,
kv_scales=kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=softmax_scale,
window_size_left=sliding_window,
softcap=softcap,
)
else:
lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
max_length = max(lengths)
attention_mask = attention_mask[:, :, :, :max_length]
enable_gqa = query_states.shape[1] != key_states.shape[1]
# Split tensors using vectorized split
query_list = torch.split(query_states, lengths.tolist(), dim=0)
key_list = torch.split(key_states, lengths.tolist(), dim=0)
value_list = torch.split(value_states, lengths.tolist(), dim=0)
padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True)
padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True)
padded_query = padded_query.transpose(1, 2).contiguous()
padded_key = padded_key.transpose(1, 2).contiguous()
padded_value = padded_value.transpose(1, 2).contiguous()
# Compute attention
attn_output = F.scaled_dot_product_attention(
padded_query,
padded_key,
padded_value,
attn_mask=attention_mask,
scale=softmax_scale,
enable_gqa=enable_gqa,
)
attn_output = attn_output.transpose(
1, 2
) # [batch_size, seq_len, num_heads, head_dim]
max_seq_len = padded_query.size(2)
seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze(
0
)
lengths_tensor = torch.tensor(
lengths, device=padded_query.device
).unsqueeze(1)
mask = seq_range < lengths_tensor # [batch, max_seq_len]
attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim]
else:
attn_output = paged_attention(
query_states,
kv_cache,
kv_head_mapping,
softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=kv_scales,
softcap=softcap,
window_size_left=sliding_window,
)
attn_output = attn_output.view(-1, num_heads * head_dim)
return attn_output, None
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
# TODO: implement
# tgi_cross_attention_forward
class TransformersFlashVlmCausalLM(VlmCausalLM):
def __init__(
self,
model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
processor_class=AutoProcessor,
processor_kwargs=None,
kv_cache_dtype: Optional[torch.dtype] = None,
batch_class=VlmCausalLMBatch,
):
self.batch_class = batch_class
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
self.dtype = dtype
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
dtype = default_dtype if dtype is None else dtype
else:
raise ValueError(
"Flash `Transformers` modeling backend is not available on cpu."
)
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
**processor_kwargs,
)
attn_implementation = {
"text_config": "tgi",
"vision_config": "sdpa",
}
model = model_class.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
attn_implementation=attn_implementation,
device_map=device if world_size == 1 else None,
tp_plan="auto" if world_size > 1 else None,
)
torch.distributed.barrier(group=self.process_group)
self.config = model.config
config = model.config
# VLM models define the config we care about in their text_config
text_config = getattr(model.config, "text_config", None)
if text_config is not None:
config = text_config
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None and isinstance(
model.config.eos_token_id, int
):
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(config, "head_dim"):
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
# Skip it for models in the exception list
if config.model_type not in REPLICATED_ATTENTION_MODELS:
self.num_heads = self.num_heads // self.process_group.size()
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_prefill_state,
create_decode_state,
create_prefill_with_paged_kv_state,
)
self.prefill_state = create_prefill_state(device=device)
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
device=device
)
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.num_groups = self.num_heads // self.num_kv_heads
# Those will never change and will be used in the forwards
self.kv_head_mapping = torch.arange(
0, self.num_kv_heads, dtype=torch.int32, device=device
).repeat_interleave(self.num_groups)
# This means no scale
self.kv_scales = KVScales(
torch.tensor(1.0, device=device),
torch.tensor(1.0, device=device),
)
# Skip FlashCausalLM init.
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
# We first copy the original model.forward because we still need it in the monkey patch
self.model.original_forward = self.model.forward
self.model.forward = self._model_forward
self.model.get_position_ids = self.get_position_ids
torch.distributed.barrier(group=self.process_group)
def get_position_ids(self, input_ids, image_grid_thw, position_ids):
return position_ids
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
return {
"input_ids": input_ids.unsqueeze(0),
"position_ids": position_ids.unsqueeze(0),
}
def post_process_outputs(self, logits, lm_head_indices):
return logits.squeeze(dim=0)
@classmethod
def fallback(
cls,
model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
batch_class: Optional[type] = VlmCausalLMBatch,
processor_kwargs: Optional[dict] = None,
):
return cls(
model_id=model_id,
model_class=model_class,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
batch_class=batch_class,
processor_kwargs=processor_kwargs,
)
def _model_forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[KVCache],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
lm_head_indices: Optional[torch.Tensor],
prefill_cache_indices=None, # not used, but passed to match original signature
adapter_data=None, # not supported, but passed to match original signature
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
):
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
inputs = self.pre_process_inputs(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
)
# This is equivalent to `self.model.forward`, see the monkey patch in __init__
logits = self.model.original_forward(
input_ids=inputs["input_ids"],
position_ids=inputs["position_ids"],
past_key_values=None, # we use self.kv_cache instead of transformers cache object
use_cache=False, # we use self.kv_cache instead of transformers cache object
logits_to_keep=logits_to_keep,
return_dict=True,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
kv_head_mapping=self.kv_head_mapping,
kv_scales=self.kv_scales,
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
attention_mask=inputs.get("attention_mask", None),
use_sdpa=inputs.get("use_sdpa", False),
cache_position=inputs.get("cache_position", None),
).logits
logits = self.post_process_outputs(logits, lm_head_indices)
return logits, None
class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor):
if image_grid_thw is None:
return (
torch.arange(input_ids.shape[0], device=input_ids.device)
.unsqueeze(1)
.repeat(1, 3)
)
spatial_merge_size = self.config.vision_config.spatial_merge_size
vision_start_token_id = self.config.vision_start_token_id
vision_end_token_id = self.config.vision_end_token_id
device = input_ids.device
dtype = input_ids.dtype
input_ids_len = input_ids.shape[0]
vision_starts = torch.where(input_ids == vision_start_token_id)[0]
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
prev_vision_end = torch.cat(
[torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
)
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
vision_widths_max = torch.cat(
[
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
image_grid_thw[:-1, 2] // spatial_merge_size,
]
)
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
# create position ids for each vision segment based on the image grid
llm_pos_ids_list = []
for i, _ in enumerate(vision_segments):
t, h, w = (
image_grid_thw[i][0],
image_grid_thw[i][1] // spatial_merge_size,
image_grid_thw[i][2] // spatial_merge_size,
)
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
w_indices = torch.arange(w, device=device).repeat(t * h)
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
# offset by the position of the last vision segment
im = image_position_ids + vision_segment_lengths[i]
llm_pos_ids_list.append(im)
# create position ids for each text segment
text_ranges = [
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ text_segment_lengths[i]
for i, seq_len in enumerate(text_lengths_between_vision)
]
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
# import ipdb
# ipdb.set_trace()
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
full_llm_pos_ids_list.append(m + max_s)
position_ids = (
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
return position_ids
def post_process_outputs(self, logits, lm_head_indices):
return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
input_ids = input_ids.unsqueeze(0)
position_ids = position_ids.transpose(0, 1).unsqueeze(1)
return {"input_ids": input_ids, "position_ids": position_ids}
class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
def get_attention_mask(self, input_ids, cu_seqlen_prefill):
device = input_ids.device
dtype = self.dtype
min_dtype = torch.finfo(dtype).min
lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
batch_size = len(lengths)
sequence_length = max(lengths)
target_length = sequence_length
# Create the padding mask from the computed lengths.
# pad_mask: [batch, sequence_length] where True indicates valid tokens.
seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
# Build the base causal mask (for non-image tokens):
causal_mask = torch.tril(
torch.ones(
(sequence_length, sequence_length), dtype=torch.bool, device=device
)
)
base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
1
) # [batch, sequence_length, sequence_length]
base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
image_token_mask = torch.nn.utils.rnn.pad_sequence(
torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
)
bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
1
)
# Combine the causal base mask and the bidirectional mask.
combined_mask = torch.logical_or(
base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
).to(device)
# combined_mask now has shape [batch, 1, sequence_length, sequence_length]
full_attention_mask = torch.zeros(
(batch_size, 1, sequence_length, target_length),
device=device,
dtype=torch.bool,
)
full_attention_mask[:, :, :, :sequence_length] = combined_mask
final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
return final_attention_mask
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
inputs = {
"input_ids": input_ids.unsqueeze(0),
"position_ids": position_ids.unsqueeze(0),
}
if cu_seqlen_prefill is not None:
attention_mask = self.get_attention_mask(
input_ids.squeeze(0), cu_seqlen_prefill
)
inputs["attention_mask"] = attention_mask
inputs["use_sdpa"] = True
return inputs
class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
inputs["cache_position"] = position_ids
inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
return inputs

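The use_sdpa branch of tgi_flash_attention_forward above (enabled for prefill by the Gemma 3 subclass, which needs a partly bidirectional mask) pads TGI's packed sequences into a regular batch, runs torch.nn.functional.scaled_dot_product_attention with an explicit float mask, and then drops the padding again. A self-contained sketch of that round trip with dummy tensors; all sizes here are arbitrary:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two packed sequences of lengths 3 and 5, stored the way TGI passes them:
# [total_tokens, num_heads, head_dim] with no batch dimension.
lengths = [3, 5]
num_heads, head_dim = 2, 8
q = torch.randn(sum(lengths), num_heads, head_dim)
k = torch.randn(sum(lengths), num_heads, head_dim)
v = torch.randn(sum(lengths), num_heads, head_dim)

def pad(t):
    # Ragged [total, heads, dim] -> padded [batch, heads, max_len, dim]
    return torch.nn.utils.rnn.pad_sequence(
        torch.split(t, lengths), batch_first=True
    ).transpose(1, 2)

pq, pk, pv = pad(q), pad(k), pad(v)

# Causal mask that also hides the padded positions.
max_len = max(lengths)
valid = torch.arange(max_len).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)
mask = valid.unsqueeze(2) & valid.unsqueeze(1)
mask = mask & torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
attn_mask = torch.zeros(len(lengths), 1, max_len, max_len)
attn_mask.masked_fill_(~mask.unsqueeze(1), float("-inf"))

out = F.scaled_dot_product_attention(pq, pk, pv, attn_mask=attn_mask)

# Un-pad back to the packed layout; rows that belong to padding are discarded here.
out = out.transpose(1, 2)  # [batch, max_len, heads, dim]
packed = out[valid]        # [total_tokens, heads, dim]
print(packed.shape)        # torch.Size([8, 2, 8])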

@ -29,6 +29,33 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
"""
Create a structured string representation of image tokens
Args:
num_patches: Number of patches in the image
Returns:
String with appropriate image tokens
"""
img_string = "<|image_start|>"
ratio_h, ratio_w = aspect_ratio
if ratio_h * ratio_w > 1:
for yy in range(ratio_h):
for xx in range(ratio_w):
img_string += "<|patch|>" * num_patches_per_chunk
if xx < ratio_w - 1:
img_string += "<|tile_x_separator|>"
img_string += "<|tile_y_separator|>"
img_string += "<|image|>"
img_string += "<|patch|>" * num_patches_per_chunk
img_string += "<|image_end|>"
return img_string
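# Worked example (hypothetical arguments): prompt_split_image_llama4((1, 2), 2),
# i.e. a 1x2 tile grid with 2 patch tokens per chunk, produces on a single line:
#   <|image_start|><|patch|><|patch|><|tile_x_separator|><|patch|><|patch|>
#   <|tile_y_separator|><|image|><|patch|><|patch|><|image_end|>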
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
*,
@ -134,6 +161,23 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
elif config.model_type == "llama4":
patch_size = config.vision_config.patch_size
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
aspect_ratios = image_input["aspect_ratios"][image_id]
image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
num_patches_per_chunk = int(
(image_height // patch_size)
* (image_width // patch_size)
// downsample_ratio
)
tokens_for_this_image = prompt_split_image_llama4(
aspect_ratios, num_patches_per_chunk
)
return tokens_for_this_image
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
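# Back-of-the-envelope check of the llama4 branch above, with illustrative
# values only (not read from a real checkpoint): a 336x336 chunk,
# patch_size = 14, pixel_shuffle_ratio = 0.5.
#   downsample_ratio      = int(round(1 / 0.5**2))          = 4
#   num_patches_per_chunk = (336 // 14) * (336 // 14) // 4  = 144
# so each chunk contributes 144 <|patch|> tokens to the prompt string.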
@ -252,6 +296,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
elif config.model_type == "llama4":
images.append(image)
else:
images.append([image])
else:
@ -285,7 +331,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
processor, image_inputs, config, image_id
)
image_id += 1
# from pdb import set_trace; set_trace()
full_text = image_text_replacement_fixup(config, full_text)
input_ids = tokenizer(
full_text,
@ -372,9 +418,6 @@ class VlmCausalLM(FlashCausalLM):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def forward(
self,
batch: VlmCausalLMBatch,
@ -442,12 +485,6 @@ class VlmCausalLM(FlashCausalLM):
)
batch.position_ids = position_ids
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
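# Sketch of the bucket lookup above with hypothetical captured graph sizes:
# pick the smallest padded batch size that can still hold the current batch.
#   cuda_graphs = {1: ..., 2: ..., 4: ..., 8: ...}; bs = 3
#   sorted([k for k in cuda_graphs.keys() if k >= bs]) == [4, 8]  -> run the bs=4 graph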

View File

@ -707,9 +707,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/14/f1e15b851d1c2af5b0b1a82bf8eb10bda2da62d98180220ba6fd8879bb5b/hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad", size = 1160240 }, { url = "https://files.pythonhosted.org/packages/a1/14/f1e15b851d1c2af5b0b1a82bf8eb10bda2da62d98180220ba6fd8879bb5b/hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad", size = 1160240 },
] ]
[[package]]
name = "hf-xet"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/64/46/db229dddc55121478105940b610fef1b466c414da02be9d4daa5602a2527/hf_xet-1.0.0.tar.gz", hash = "sha256:5e0ca891ce599fd753e7ffbdc182207d952a93e6252eeb92118475d6866bb093", size = 257192 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/0a/c16f8766fa3cd520292b1a765e9b50b8390bce4c2ed7657db9534551f5ed/hf_xet-1.0.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6106304f92bbce7c9b8509f6f735f2e8ce95e4dc32af8876e874c48b15ca1903", size = 5001841 },
{ url = "https://files.pythonhosted.org/packages/e3/9f/cca55edd85d03fc98c743bcc093965740a7440e909779c558039d6838f03/hf_xet-1.0.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:4d0bc7a3e6c1d21fcbb48e8726e3b19a2460e95971375e55e9a5f73ec7079a86", size = 4805318 },
{ url = "https://files.pythonhosted.org/packages/d1/0b/28bda7ac9d699dcfb96f628aa135ddca3f0f77e9716351aab2b83966f957/hf_xet-1.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23dee64f114ea9a272ff71a6a755e025d7a075a6a0dbf6da0990fe9211831ccf", size = 53504907 },
{ url = "https://files.pythonhosted.org/packages/cb/04/ef1f7249a813841d193cbab2ef4d1d7d67c66c61d21d45223a72fdc5c88e/hf_xet-1.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d5f160550508ed87783d1eca4288602a713d1c45ec517d937acb9d93120f0cab", size = 52410434 },
{ url = "https://files.pythonhosted.org/packages/81/b3/e7abec2619ecd9d1c743adfe79fa69cf84530f530969daf3dc804efef65b/hf_xet-1.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5ebd79db87df0b9d3607e7c9a6bb0662c10e36992f522f66b1d2a7fe93f53f27", size = 53465113 },
{ url = "https://files.pythonhosted.org/packages/df/82/b51f3b6e5c6f33e91220c37b17760229704c58e79ab0fcfd0fd3b55803d3/hf_xet-1.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8e6d2625971b4affad634835db82d5392f38de874205a9573e0dd3f0f9cb136f", size = 53461632 },
{ url = "https://files.pythonhosted.org/packages/95/d2/32defba26d995f7acdc4fe3e5911473b25aff5b75c5a2532786435a709e8/hf_xet-1.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:b446964bd75eb7f6b4d983c47241b2023eadfad1f56137ed00e1ca6fc278faec", size = 4121808 },
]
[[package]]
name = "huggingface-hub"
version = "0.29.1"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
@ -720,9 +735,9 @@ dependencies = [
{ name = "tqdm" }, { name = "tqdm" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776 } sdist = { url = "https://files.pythonhosted.org/packages/78/be/049689a7197630e75c4bb53021cb209a56617c9bf39b3a0950650d1f96e1/huggingface_hub-0.30.1.tar.gz", hash = "sha256:f379e8b8d0791295602538856638460ae3cf679c7f304201eb80fb98c771950e", size = 400784 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049 }, { url = "https://files.pythonhosted.org/packages/99/e3/2232d0e726d4d6ea69643b9593d97d0e7e6ea69c2fe9ed5de34d476c1c47/huggingface_hub-0.30.1-py3-none-any.whl", hash = "sha256:0f6aa5ec5a4e68e5b9e45d556b4e5ea180c58f5a5ffa734e7f38c9d573028959", size = 481170 },
] ]
[[package]]
@ -2563,6 +2578,7 @@ dependencies = [
{ name = "grpcio-reflection" }, { name = "grpcio-reflection" },
{ name = "grpcio-status" }, { name = "grpcio-status" },
{ name = "hf-transfer" }, { name = "hf-transfer" },
{ name = "hf-xet" },
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
{ name = "kernels" }, { name = "kernels" },
{ name = "loguru" }, { name = "loguru" },
@ -2628,7 +2644,8 @@ requires-dist = [
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.51.1,<2.0" }, { name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.51.1,<2.0" },
{ name = "grpcio-tools", marker = "extra == 'gen'", specifier = ">=1.69.0" }, { name = "grpcio-tools", marker = "extra == 'gen'", specifier = ">=1.69.0" },
{ name = "hf-transfer", specifier = ">=0.1.8" }, { name = "hf-transfer", specifier = ">=0.1.8" },
{ name = "huggingface-hub", specifier = ">=0.29.0" }, { name = "hf-xet", specifier = ">=1.0.0" },
{ name = "huggingface-hub", specifier = ">=0.30.1" },
{ name = "kernels", specifier = ">=0.2.1" }, { name = "kernels", specifier = ">=0.2.1" },
{ name = "loguru", specifier = ">=0.7.3" }, { name = "loguru", specifier = ">=0.7.3" },
{ name = "mypy-protobuf", marker = "extra == 'gen'", specifier = ">=3.6.0" }, { name = "mypy-protobuf", marker = "extra == 'gen'", specifier = ">=3.6.0" },