Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit f219124711: Merge branch 'main' into mi300-compat

.github/workflows/build.yaml (vendored)
@@ -27,7 +27,7 @@ jobs:
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
-     EC2_AMI_ID: ami-03cfed9ea28f4b002
+     EC2_AMI_ID: ami-0789b6925c11b1fb2
      EC2_INSTANCE_TYPE: g5.12xlarge
      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
      EC2_SECURITY_GROUP: sg-030175c435ac141d6
@@ -43,7 +43,7 @@ ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
-ARG MAMBA_VERSION=23.3.1-1
+ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx

@@ -181,6 +181,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
        ca-certificates \
        make \
        curl \
        git \
        && rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
@@ -14,5 +14,10 @@

__version__ = "0.6.0"

DEPRECATION_WARNING = (
    "`text_generation` clients are deprecated and will be removed in the near future. "
    "Please use the `InferenceClient` from the `huggingface_hub` package instead."
)

from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
@@ -1,10 +1,12 @@
import json
import requests
import warnings

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union

from text_generation import DEPRECATION_WARNING
from text_generation.types import (
    StreamResponse,
    Response,

@@ -19,6 +21,9 @@ from text_generation.types import (
)
from text_generation.errors import parse_error

# emit deprecation warnings
warnings.simplefilter("always", DeprecationWarning)


class Client:
    """Client to make calls to a text-generation-inference instance

@@ -59,6 +64,7 @@ class Client:
            timeout (`int`):
                Timeout in seconds
        """
        warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies

@@ -449,6 +455,7 @@ class AsyncClient:
            timeout (`int`):
                Timeout in seconds
        """
        warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
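Since the client hunks above only add a deprecation path, migrating callers amounts to swapping in the client named by the warning. A minimal sketch, assuming a TGI instance is already serving at a placeholder URL (http://localhost:8080):

from huggingface_hub import InferenceClient

# Point the hub client at a running text-generation-inference server.
client = InferenceClient("http://localhost:8080")

# Roughly equivalent to text_generation.Client(...).generate(...).generated_text.
output = client.text_generation("What is deep learning?", max_new_tokens=10)
print(output)

The deprecated `Client` and `AsyncClient` keep working for now, but emit the `DeprecationWarning` added above as soon as they are constructed.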
@@ -9,6 +9,7 @@ The following models are optimized and can be served with TGI, which uses custom
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [GPT-2](https://huggingface.co/openai-community/gpt2)
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Llama](https://github.com/facebookresearch/llama)
- [OPT](https://huggingface.co/facebook/opt-66b)
integration-tests/images/cow_beach.png (new binary file, 66 KiB; content not shown)
@ -0,0 +1,99 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1835938,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.171875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6425781,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.7314453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.68603516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.005393982,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.31079102,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08300781,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.58984375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.953125,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0957031,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8095703,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9375,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
}
|
@ -0,0 +1,398 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1835938,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.171875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6425781,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.7314453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.68603516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.005672455,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3251953,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08294678,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5854492,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9423828,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0800781,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8369141,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0683594,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9711914,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1660156,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.1796875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6376953,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.72216797,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.0054779053,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3190918,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08319092,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5839844,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0878906,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8496094,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9370117,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1660156,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.1796875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6376953,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.72216797,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.0054779053,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3190918,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08319092,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5839844,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0878906,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8496094,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9370117,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1660156,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.1796875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6376953,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.72216797,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.0054779053,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3190918,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08319092,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5839844,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0878906,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8496094,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9370117,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
}
|
||||
]
|
@@ -0,0 +1,25 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 2,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 54901,
        "logprob": -0.72753906,
        "special": false,
        "text": "beach"
      },
      {
        "id": 1,
        "logprob": -0.011009216,
        "special": true,
        "text": "<eos>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "beach"
}
integration-tests/models/test_flash_gpt2.py (new file, 44 lines)

@@ -0,0 +1,44 @@
import pytest


@pytest.fixture(scope="module")
def flash_gpt2_handle(launcher):
    with launcher("openai-community/gpt2", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gpt2(flash_gpt2_handle):
    await flash_gpt2_handle.health(300)
    return flash_gpt2_handle.client


@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
    response = await flash_gpt2.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
    responses = await generate_load(
        flash_gpt2,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )

    generated_texts = [r.generated_text for r in responses]

    assert len(generated_texts) == 4
    assert all(
        [text == generated_texts[0] for text in generated_texts]
    ), generated_texts

    assert responses == response_snapshot
integration-tests/models/test_flash_pali_gemma.py (new file, 39 lines)

@@ -0,0 +1,39 @@
import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    with launcher(
        "google/paligemma-3b-pt-224",
        num_shard=1,
        revision="float16",
        max_input_length=4000,
        max_total_tokens=4096,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    await flash_pali_gemma_handle.health(300)
    return flash_pali_gemma_handle.client


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"![]({cow})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    assert response.generated_text == "beach"
    assert response == response_snapshot
@@ -100,7 +100,6 @@ impl LlavaNext {
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
    image_size: usize,

@@ -108,7 +107,6 @@ pub struct ClipVisionModel {
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}

@@ -118,6 +116,24 @@ impl Idefics2 {
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PaliTextConfig {
    num_image_tokens: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {
    text_config: PaliTextConfig,
}

impl Paligemma {
    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
        self.text_config.num_image_tokens
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]

@@ -132,6 +148,7 @@ pub enum Config {
    Santacoder,
    Bloom,
    Mpt,
    Gpt2,
    GptNeox,
    Phi,
    #[serde(rename = "phi-msft")]

@@ -139,6 +156,7 @@ pub enum Config {
    Phi3,
    Llama,
    Baichuan,
    Paligemma(Paligemma),
    Gemma,
    Cohere,
    Drbx,
@ -979,24 +979,28 @@ mod tests {
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -1049,30 +1053,35 @@ mod tests {
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hi again!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -1130,24 +1139,28 @@ mod tests {
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -1189,24 +1202,28 @@ mod tests {
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -1234,18 +1251,21 @@ mod tests {
|
||||
content: Some("Hello, how are you?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("I'm doing great. How can I help you today?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("I'd like to show off how chat templating works!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
@ -1257,6 +1277,7 @@ mod tests {
|
||||
),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}]
|
||||
.iter()
|
||||
.chain(&example_chat)
|
||||
@ -1401,12 +1422,14 @@ mod tests {
|
||||
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("How many helicopters can a human eat in one sitting?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
add_generation_prompt: true,
|
||||
|
@@ -546,6 +546,7 @@ impl ChatCompletion {
                content: output,
                name: None,
                tool_calls,
                tool_call_id: None,
            },
            logprobs: return_logprobs
                .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),

@@ -881,7 +882,7 @@ pub(crate) struct ChatTemplateInputs<'a> {

#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall {
-    pub id: u32,
+    pub id: String,
    pub r#type: String,
    pub function: FunctionDefinition,
}

@@ -954,13 +955,16 @@ pub(crate) struct Message {
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schema(example = "My name is David and I")]
-    #[serde(deserialize_with = "message_content_serde::deserialize")]
+    #[serde(default, deserialize_with = "message_content_serde::deserialize")]
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[schema(example = "\"David\"")]
    pub name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[schema(example = "\"get_weather\"")]
    pub tool_call_id: Option<String>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
@@ -988,7 +988,6 @@ async fn chat_completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    metrics::increment_counter!("tgi_request_count");

    let ChatRequest {
        logprobs,
        max_tokens,

@@ -1160,7 +1159,7 @@ async fn chat_completions(
        )
    })?;
    let tool_calls = vec![ToolCall {
-        id: 0,
+        id: "0".to_string(),
        r#type: "function".to_string(),
        function: FunctionDefinition {
            description: None,
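The `id` change above aligns the router's tool calls with the OpenAI chat completions schema, where tool call ids are JSON strings rather than integers. A rough sketch of the resulting `tool_calls` entry (illustrative values only; the function name is a placeholder, not taken from the diff):

# Illustrative tool call shape as serialized after the id: u32 -> String change.
tool_call = {
    "id": "0",  # was an integer before this commit; now a string
    "type": "function",
    "function": {
        "name": "placeholder_function",  # hypothetical name, for illustration only
        "description": None,
        "arguments": "{}",
    },
}
assert isinstance(tool_call["id"], str)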
@@ -544,6 +544,30 @@ fn prepare_input(
            inputs = modified_inputs;
            tokenizer_query
        }
        Some(Config::Paligemma(config)) => {
            let mut modified_inputs = String::with_capacity(inputs.len());
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    modified_inputs.push_str(&inputs[start..chunk_start]);
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                tokenizer_query.push_str(&"<image>".repeat(slots));
                modified_inputs.push_str(&image_uri);
                start = chunk_end;
            }
            if start != inputs.len() - 1 {
                modified_inputs.push_str(&inputs[start..]);
                tokenizer_query.push_str(&inputs[start..]);
            }
            inputs = modified_inputs;
            tokenizer_query
        }
        Some(Config::Idefics2(config)) => {
            let mut modified_inputs = String::with_capacity(inputs.len());
            let mut tokenizer_query = String::with_capacity(inputs.len());
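The new `Paligemma` arm mirrors the surrounding image-handling branches: every markdown image chunk found in the prompt is fetched, and the tokenizer-side query gets `num_image_tokens` copies of the `<image>` placeholder, independent of the image's size. A minimal Python sketch of that expansion, with a hypothetical regex and helper name standing in for the router's internals:

import re

# Hypothetical stand-ins for the router's image regex and Paligemma config value.
IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")  # markdown image chunks like ![](data:image/png;base64,...)

def expand_image_placeholders(inputs: str, num_image_tokens: int) -> str:
    """Build the tokenizer-side query: text is kept as-is, every image chunk
    becomes `num_image_tokens` copies of the "<image>" placeholder."""
    parts = []
    start = 0
    for chunk in IMAGE_RE.finditer(inputs):
        parts.append(inputs[start:chunk.start()])   # text before the image chunk
        parts.append("<image>" * num_image_tokens)  # constant slot count for Paligemma
        start = chunk.end()
    parts.append(inputs[start:])                    # trailing text after the last image
    return "".join(parts)

# Example with a single embedded image and num_image_tokens = 3.
print(expand_image_placeholders("![](data:image/png;base64,AAAA)Where is the cow standing?\n", 3))

For PaliGemma the slot count comes straight from `text_config.num_image_tokens` in the model config, which is why `get_number_of_features` above ignores its height and width arguments.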
server/poetry.lock (generated)
@ -359,43 +359,45 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "2.14.4"
|
||||
version = "2.19.1"
|
||||
description = "HuggingFace community-driven open-source library of datasets"
|
||||
optional = true
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"},
|
||||
{file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"},
|
||||
{file = "datasets-2.19.1-py3-none-any.whl", hash = "sha256:f7a78d15896f45004ccac1c298f3c7121f92f91f6f2bfbd4e4f210f827e6e411"},
|
||||
{file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "*"
|
||||
dill = ">=0.3.0,<0.3.8"
|
||||
fsspec = {version = ">=2021.11.1", extras = ["http"]}
|
||||
huggingface-hub = ">=0.14.0,<1.0.0"
|
||||
dill = ">=0.3.0,<0.3.9"
|
||||
filelock = "*"
|
||||
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
|
||||
huggingface-hub = ">=0.21.2"
|
||||
multiprocess = "*"
|
||||
numpy = ">=1.17"
|
||||
packaging = "*"
|
||||
pandas = "*"
|
||||
pyarrow = ">=8.0.0"
|
||||
pyarrow = ">=12.0.0"
|
||||
pyarrow-hotfix = "*"
|
||||
pyyaml = ">=5.1"
|
||||
requests = ">=2.19.0"
|
||||
tqdm = ">=4.62.1"
|
||||
xxhash = "*"
|
||||
|
||||
[package.extras]
|
||||
apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"]
|
||||
apache-beam = ["apache-beam (>=2.26.0)"]
|
||||
audio = ["librosa", "soundfile (>=0.12.1)"]
|
||||
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
|
||||
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
|
||||
jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"]
|
||||
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
|
||||
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
|
||||
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
|
||||
quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"]
|
||||
quality = ["ruff (>=0.3.0)"]
|
||||
s3 = ["s3fs"]
|
||||
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
|
||||
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
||||
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"]
|
||||
tensorflow = ["tensorflow (>=2.6.0)"]
|
||||
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
|
||||
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=6.2.1)"]
|
||||
|
||||
@ -418,17 +420,18 @@ dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.3.7"
|
||||
version = "0.3.8"
|
||||
description = "serialize all of Python"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"},
|
||||
{file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"},
|
||||
{file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
|
||||
{file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
graph = ["objgraph (>=1.7.2)"]
|
||||
profile = ["gprof2dot (>=2022.7.29)"]
|
||||
|
||||
[[package]]
|
||||
name = "diskcache"
|
||||
@ -871,13 +874,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.19.4"
|
||||
version = "0.23.0"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"},
|
||||
{file = "huggingface_hub-0.19.4.tar.gz", hash = "sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"},
|
||||
{file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
|
||||
{file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -890,16 +893,17 @@ tqdm = ">=4.42.1"
|
||||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[package.extras]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
cli = ["InquirerPy (==0.3.4)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||
inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
|
||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||
inference = ["aiohttp", "minijinja (>=1.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
|
||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["torch"]
|
||||
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["safetensors", "torch"]
|
||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||
|
||||
[[package]]
|
||||
@ -1282,31 +1286,27 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "multiprocess"
|
||||
version = "0.70.15"
|
||||
version = "0.70.16"
|
||||
description = "better multiprocessing and multithreading in Python"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"},
|
||||
{file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"},
|
||||
{file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"},
|
||||
{file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"},
|
||||
{file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"},
|
||||
{file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"},
|
||||
{file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"},
|
||||
{file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"},
|
||||
{file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"},
|
||||
{file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"},
|
||||
{file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"},
|
||||
{file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"},
|
||||
{file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"},
|
||||
{file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"},
|
||||
{file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"},
|
||||
{file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"},
|
||||
{file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
|
||||
{file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
|
||||
{file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
|
||||
{file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
|
||||
{file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
|
||||
{file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
|
||||
{file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
|
||||
{file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
|
||||
{file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
|
||||
{file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
|
||||
{file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
|
||||
{file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
dill = ">=0.3.7"
|
||||
dill = ">=0.3.8"
|
||||
|
||||
[[package]]
|
||||
name = "nest-asyncio"
|
||||
@ -2034,52 +2034,63 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "16.0.0"
|
||||
version = "16.1.0"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:22a1fdb1254e5095d629e29cd1ea98ed04b4bbfd8e42cc670a6b639ccc208b60"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:574a00260a4ed9d118a14770edbd440b848fcae5a3024128be9d0274dbcaf858"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0815d0ddb733b8c1b53a05827a91f1b8bde6240f3b20bf9ba5d650eb9b89cdf"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df0080339387b5d30de31e0a149c0c11a827a10c82f0c67d9afae3981d1aabb7"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:edf38cce0bf0dcf726e074159c60516447e4474904c0033f018c1f33d7dac6c5"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:91d28f9a40f1264eab2af7905a4d95320ac2f287891e9c8b0035f264fe3c3a4b"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:99af421ee451a78884d7faea23816c429e263bd3618b22d38e7992c9ce2a7ad9"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d22d0941e6c7bafddf5f4c0662e46f2075850f1c044bf1a03150dd9e189427ce"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:266ddb7e823f03733c15adc8b5078db2df6980f9aa93d6bb57ece615df4e0ba7"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cc23090224b6594f5a92d26ad47465af47c1d9c079dd4a0061ae39551889efe"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56850a0afe9ef37249d5387355449c0f94d12ff7994af88f16803a26d38f2016"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:705db70d3e2293c2f6f8e84874b5b775f690465798f66e94bb2c07bab0a6bb55"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:5448564754c154997bc09e95a44b81b9e31ae918a86c0fcb35c4aa4922756f55"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:729f7b262aa620c9df8b9967db96c1575e4cfc8c25d078a06968e527b8d6ec05"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:fb8065dbc0d051bf2ae2453af0484d99a43135cadabacf0af588a3be81fbbb9b"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:20ce707d9aa390593ea93218b19d0eadab56390311cb87aad32c9a869b0e958c"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5823275c8addbbb50cd4e6a6839952682a33255b447277e37a6f518d6972f4e1"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab8b9050752b16a8b53fcd9853bf07d8daf19093533e990085168f40c64d978"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:42e56557bc7c5c10d3e42c3b32f6cff649a29d637e8f4e8b311d334cc4326730"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a7abdee4a4a7cfa239e2e8d721224c4b34ffe69a0ca7981354fe03c1328789b"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:ef2f309b68396bcc5a354106741d333494d6a0d3e1951271849787109f0229a6"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ed66e5217b4526fa3585b5e39b0b82f501b88a10d36bd0d2a4d8aa7b5a48e2df"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc8814310486f2a73c661ba8354540f17eef51e1b6dd090b93e3419d3a097b3a"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c2f5e239db7ed43e0ad2baf46a6465f89c824cc703f38ef0fde927d8e0955f7"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f293e92d1db251447cb028ae12f7bc47526e4649c3a9924c8376cab4ad6b98bd"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:dd9334a07b6dc21afe0857aa31842365a62eca664e415a3f9536e3a8bb832c07"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d91073d1e2fef2c121154680e2ba7e35ecf8d4969cc0af1fa6f14a8675858159"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:71d52561cd7aefd22cf52538f262850b0cc9e4ec50af2aaa601da3a16ef48877"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:b93c9a50b965ee0bf4fef65e53b758a7e8dcc0c2d86cebcc037aaaf1b306ecc0"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d831690844706e374c455fba2fb8cfcb7b797bfe53ceda4b54334316e1ac4fa4"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35692ce8ad0b8c666aa60f83950957096d92f2a9d8d7deda93fb835e6053307e"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dd3151d098e56f16a8389c1247137f9e4c22720b01c6f3aa6dec29a99b74d80"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bd40467bdb3cbaf2044ed7a6f7f251c8f941c8b31275aaaf88e746c4f3ca4a7a"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:00a1dcb22ad4ceb8af87f7bd30cc3354788776c417f493089e0a0af981bc8d80"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fda9a7cebd1b1d46c97b511f60f73a5b766a6de4c5236f144f41a5d5afec1f35"},
|
||||
{file = "pyarrow-16.0.0.tar.gz", hash = "sha256:59bb1f1edbbf4114c72415f039f1359f1a57d166a331c3229788ccbfbb31689a"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
|
||||
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.16.6"
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow-hotfix"
|
||||
version = "0.6"
|
||||
description = ""
|
||||
optional = true
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"},
|
||||
{file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.7.1"
|
||||
@ -3016,18 +3027,16 @@ telegram = ["requests"]
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.40.2"
|
||||
version = "4.41.0.dev0"
|
||||
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "transformers-4.40.2-py3-none-any.whl", hash = "sha256:71cb94301ec211a2e1d4b8c8d18dcfaa902dfa00a089dceca167a8aa265d6f2d"},
|
||||
{file = "transformers-4.40.2.tar.gz", hash = "sha256:657b6054a2097671398d976ad46e60836e7e15f9ea9551631a96e33cb9240649"},
|
||||
]
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
huggingface-hub = ">=0.19.3,<1.0"
|
||||
huggingface-hub = ">=0.23.0,<1.0"
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
pyyaml = ">=5.1"
|
||||
@ -3040,27 +3049,25 @@ tqdm = ">=4.27"
|
||||
[package.extras]
|
||||
accelerate = ["accelerate (>=0.21.0)"]
|
||||
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"]
|
||||
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
codecarbon = ["codecarbon (==1.2.0)"]
|
||||
deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"]
|
||||
docs-specific = ["hf-doc-builder"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
|
||||
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
ftfy = ["ftfy"]
|
||||
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
|
||||
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
|
||||
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)"]
|
||||
modelcreation = ["cookiecutter (==1.7.3)"]
|
||||
natten = ["natten (>=0.14.6,<0.15.0)"]
|
||||
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
|
||||
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||
optuna = ["optuna"]
|
||||
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"]
|
||||
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"]
|
||||
ray = ["ray[tune] (>=2.7.0)"]
|
||||
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
|
||||
sagemaker = ["sagemaker (>=2.31.0)"]
|
||||
@ -3069,19 +3076,25 @@ serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
|
||||
sigopt = ["sigopt"]
|
||||
sklearn = ["scikit-learn"]
|
||||
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
timm = ["timm"]
|
||||
tokenizers = ["tokenizers (>=0.19,<0.20)"]
|
||||
torch = ["accelerate (>=0.21.0)", "torch"]
|
||||
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib_metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"]
|
||||
video = ["av (==9.2.0)", "decord (==0.6.0)"]
|
||||
vision = ["Pillow (>=10.0.1,<=15.0)"]
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/huggingface/transformers.git"
|
||||
reference = "b8aee2e"
|
||||
resolved_reference = "b8aee2e918d7ba2d5e9e80162ae26b4806873307"
|
||||
|
||||
[[package]]
|
||||
name = "triton"
|
||||
version = "2.3.0"
|
||||
@ -3488,4 +3501,4 @@ torch = ["torch"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "df83b265d0263870b5d1ae8bfd847f406abef90868fdf528ff38527b512f86c0"
|
||||
content-hash = "b2a29b0b6e32d0e7043e94b984c5731f2c27c5d95feccbeb80bd890db22d6c4a"
|
||||
|
@ -25,8 +25,9 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
|
||||
hf-transfer = "^0.1.2"
|
||||
sentencepiece = "^0.1.97"
|
||||
tokenizers = "^0.19.1"
|
||||
huggingface-hub = "^0.19.3"
|
||||
transformers = "^4.40"
|
||||
huggingface-hub = "^0.23"
|
||||
# transformers = "^4.40"
|
||||
transformers = { git = "https://github.com/huggingface/transformers.git", rev="b8aee2e" }
|
||||
einops = "^0.6.1"
|
||||
texttable = { version = "^1.6.7", optional = true }
|
||||
datasets = { version = "^2.14.0", optional = true }
|
||||
|
@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@b8aee2e918d7ba2d5e9e80162ae26b4806873307 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@b8aee2e918d7ba2d5e9e80162ae26b4806873307 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -70,7 +70,7 @@ class Linear8bitLt(torch.nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class Linear4bit(nn.Module):
|
||||
class Linear4bit(torch.nn.Module):
|
||||
def __init__(self, weight, bias, quant_type):
|
||||
super().__init__()
|
||||
self.weight = Params4bit(
|
||||
|
@ -15,9 +15,9 @@ class FastLinear(torch.nn.Module):
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(weight)
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
if bias is not None:
|
||||
self.bias = torch.nn.Parameter(bias)
|
||||
self.bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -51,6 +51,7 @@ FLASH_ATTENTION = True
|
||||
|
||||
try:
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
@ -64,6 +65,9 @@ try:
|
||||
from text_generation_server.models.flash_gemma import (
|
||||
FlashGemma,
|
||||
)
|
||||
from text_generation_server.models.pali_gemma import (
|
||||
PaliGemma,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
@ -83,6 +87,7 @@ except ImportError as e:
|
||||
HAS_FLASH_ATTN_V2_CUDA = False
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashGPT2)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
@ -325,7 +330,27 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "gpt2":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGPT2(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
|
||||
else:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == "gpt_neox":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashNeoXSharded(
|
||||
@ -654,6 +679,18 @@ def get_model(
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == "paligemma":
|
||||
if FLASH_ATTENTION:
|
||||
return PaliGemma(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))
|
||||
|
||||
if model_type == "llava_next":
|
||||
if FLASH_ATTENTION:
|
||||
|
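For context, the new gpt2 and paligemma branches above follow the dispatch pattern get_model already uses: pick the flash implementation when FLASH_ATTENTION is available, refuse sharding when no sharded fallback exists, and otherwise fall back to CausalLM. A simplified, hypothetical sketch of that control flow (not the real function, just the routing shape):

def dispatch(model_type, sharded, flash_attention):
    # simplified routing mirroring the gpt2 branch added above
    if model_type == "gpt2":
        if flash_attention:
            return "FlashGPT2"
        if sharded:
            # there is no sharded non-flash GPT-2 path
            raise NotImplementedError("Sharded GPT-2 requires flash attention")
        return "CausalLM"
    raise ValueError(f"Unsupported model_type: {model_type}")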
@ -99,8 +99,13 @@ class GemmaConfig(PretrainedConfig):
|
||||
class GemmaFastRMSNorm(FastRMSNorm):
|
||||
@classmethod
|
||||
def load(cls, prefix, weights, eps=1e-6):
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
return cls(weight, eps)
|
||||
weights.dtype = dtype
|
||||
new = cls(weight, eps)
|
||||
new.dtype = dtype
|
||||
return new
|
||||
|
||||
# perform the multiplication in full precision and downcast after
|
||||
def forward(self, hidden_states, residual=None):
|
||||
@ -111,7 +116,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states.to(self.weight.dtype), residual
|
||||
return hidden_states.to(self.dtype), residual
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
@ -153,15 +158,11 @@ def _load_gqa(config, prefix: str, weights):
|
||||
|
||||
|
||||
class FlashGemmaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = config.head_dim
|
||||
self.causal = causal
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
@ -238,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
causal=self.causal,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
@ -295,11 +297,10 @@ class GemmaMLP(nn.Module):
|
||||
|
||||
|
||||
class FlashGemmaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashGemmaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
|
||||
)
|
||||
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
@ -351,30 +352,25 @@ class FlashGemmaLayer(nn.Module):
|
||||
|
||||
|
||||
class FlashGemmaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
embed_norm = config.hidden_size**0.5
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_tokens.weight *= embed_norm
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGemmaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
causal=causal,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = GemmaFastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -385,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -394,7 +390,7 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
@ -423,13 +419,30 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashGemmaModel(config, weights)
|
||||
embed_norm = config.hidden_size**0.5
|
||||
if prefix is None:
|
||||
prefix = "model"
|
||||
else:
|
||||
prefix = f"{prefix}.model"
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_tokens.weight *= embed_norm
|
||||
|
||||
self.model = FlashGemmaModel(
|
||||
prefix=prefix, config=config, weights=weights, causal=causal
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
|
||||
prefix=(
|
||||
f"{prefix}.embed_tokens"
|
||||
if config.tie_word_embeddings
|
||||
else f"{prefix}.lm_head"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
@ -445,8 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
input_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
@ -0,0 +1,454 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
if config.quantize == "gptq":
|
||||
return _load_qkv_gptq(
|
||||
config,
|
||||
prefix,
|
||||
weights,
|
||||
)
|
||||
else:
|
||||
return _load_qkv(config, prefix, weights, head_size, num_heads)
|
||||
|
||||
|
||||
def _load_qkv_gptq(config, prefix: str, weights):
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
# Weights
|
||||
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
|
||||
|
||||
# Bias
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
shape = slice_.get_shape()
|
||||
total_size = shape[0]
|
||||
assert total_size % 3 == 0, f"Packed QKV bias size {total_size} is not divisible by 3"
|
||||
single_size = total_size // 3
|
||||
assert single_size % world_size == 0
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensors = []
|
||||
for i in range(3):
|
||||
tensor = slice_[start + i * single_size : stop + i * single_size]
|
||||
tensors.append(tensor)
|
||||
bias = torch.cat(tensors, dim=0)
|
||||
bias = bias.to(device=weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
"""Load QKV from a single, transposed matrix."""
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
||||
shape = slice_.get_shape()
|
||||
total_size = shape[1]
|
||||
assert total_size % 3 == 0, f"Packed QKV weight size {total_size} is not divisible by 3"
|
||||
world_size = weights.process_group.size()
|
||||
single_size = total_size // 3
|
||||
assert single_size % world_size == 0
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
# Weights
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensors = []
|
||||
for i in range(3):
|
||||
tensor = slice_[:, start + i * single_size : stop + i * single_size]
|
||||
tensors.append(tensor)
|
||||
weight = torch.cat(tensors, dim=1).T
|
||||
weight = weight.to(dtype=weights.dtype)
|
||||
weight = weight.to(device=weights.device)
|
||||
|
||||
# Bias
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
shape = slice_.get_shape()
|
||||
total_size = shape[0]
|
||||
single_size = total_size // 3
|
||||
block_size = single_size // world_size
|
||||
assert single_size % world_size == 0
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
b = []
|
||||
for i in range(3):
|
||||
tensor = slice_[start + i * single_size : stop + i * single_size]
|
||||
b.append(tensor)
|
||||
bias = torch.cat(b, dim=0)
|
||||
bias = bias.to(dtype=weights.dtype)
|
||||
bias = bias.to(device=weights.device)
|
||||
assert list(bias.shape) == [
|
||||
3 * num_heads * head_size
|
||||
], f"{weight.shape} != {[3 * num_heads * head_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
"""load_row, but with transposed weight matrices."""
|
||||
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Bias is only added on the first rank to avoid duplicating it across tensor-parallel shards
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelRowLinear(
|
||||
get_linear(weight, bias, config.quantize), process_group=weights.process_group
|
||||
)
|
||||
|
||||
|
||||
def load_col(config, prefix: str, weights, bias: bool):
|
||||
"""load_col, but with transposed weight matrices."""
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=1
|
||||
)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
class FlashGPT2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.query_key_value = load_qkv(
|
||||
config,
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
head_size=self.head_size,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
self.o_proj = load_row(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
query, key, value = self.query_key_value(hidden_states).split(
|
||||
self.head_size * self.num_heads, dim=1
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_heads, self.head_size)
|
||||
value = value.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class GPT2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.activation_function
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.c_fc = load_col(
|
||||
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
|
||||
)
|
||||
self.c_proj = load_row(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
intermediate_size = (
|
||||
config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size // weights.process_group.size()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
return self.c_proj(hidden_states)
|
||||
|
||||
|
||||
class FlashGPT2Layer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGPT2Attention(
|
||||
prefix=f"{prefix}.attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.ln_2",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
hidden_states,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
|
||||
return residual + mlp_output, residual
|
||||
|
||||
|
||||
class FlashGPT2Model(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGPT2Layer(
|
||||
prefix=(
|
||||
f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = nn.LayerNorm.load(
|
||||
prefix="ln_f" if not prefix else f"{prefix}.ln_f",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashGPT2ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=("wte" if not prefix else f"{prefix}.wte"),
|
||||
weights=weights,
|
||||
)
|
||||
self.embed_positions = TensorParallelEmbedding(
|
||||
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.model = FlashGPT2Model(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="wte" if not prefix else f"{prefix}.wte",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
token_embeds = self.embed_tokens(input_ids)
|
||||
position_embeds = self.embed_positions(position_ids)
|
||||
inputs_embeds = token_embeds + position_embeds
|
||||
hidden_states = self.model(
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
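The new GPT-2 loader above has to deal with checkpoints that store attention as a single Conv1D-style c_attn matrix of shape (hidden, 3 * hidden): each tensor-parallel rank slices its block out of the packed Q, K and V sections and transposes the result into the (out_features, in_features) layout a linear layer expects. A small self-contained sketch of that slicing, with made-up sizes and a two-rank split:

import torch

hidden, world_size, rank = 8, 2, 0
packed = torch.randn(hidden, 3 * hidden)   # Conv1D layout: (in, 3 * out)

single = packed.shape[1] // 3              # size of each of Q, K, V
block = single // world_size               # this rank's share of each projection
start, stop = rank * block, (rank + 1) * block

# take this rank's slice of Q, K and V, then transpose to (out, in)
shards = [packed[:, start + i * single : stop + i * single] for i in range(3)]
qkv_weight = torch.cat(shards, dim=1).T
assert qkv_weight.shape == (3 * block, hidden)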
@ -0,0 +1,110 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix="multi_modal_projector.linear",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
|
||||
text_config = config.text_config
|
||||
text_config.speculator = config.speculator
|
||||
text_config.quantize = config.quantize
|
||||
self.text_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused here
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
if cu_seqlen_prefill is not None:
|
||||
max_s += 1
|
||||
position_ids += 1
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
|
||||
# mask where image or padding tokens
|
||||
mask = input_ids == self.config.image_token_index
|
||||
|
||||
# insert image features into input embeddings
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
)
|
||||
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
|
||||
return logits, speculative_logits
|
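In the PaliGemma forward above, the vision tower's output is projected into the text hidden size and written over the embeddings at every position that holds the image token; the merged sequence then flows through the Gemma text model. A minimal sketch of that merge, with made-up sizes and an assumed placeholder image_token_index:

import torch

image_token_index = 257152                 # assumed placeholder id, not taken from the diff
hidden_size, num_image_tokens = 16, 3

input_ids = torch.tensor([[image_token_index] * num_image_tokens + [5, 9, 12]])
inputs_embeds = torch.randn(1, 6, hidden_size)
image_features = torch.randn(num_image_tokens, hidden_size)   # already projected

mask = input_ids == image_token_index                  # image-token positions
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])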
565
server/text_generation_server/models/custom_modeling/siglip.py
Normal file
@ -0,0 +1,565 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_create_4d_causal_attention_mask,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
)
|
||||
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
|
||||
|
||||
from text_generation_server.layers.tensor_parallel import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
def __init__(self, prefix, config: SiglipVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
self.patch_embedding.bias = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
|
||||
)
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.position_embedding", weights=weights
|
||||
)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values
|
||||
) # shape = [*, width, grid, grid]
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class SiglipTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: SiglipTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.max_position_embeddings, embed_dim
|
||||
)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = (
|
||||
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class SiglipAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.head_size = self.head_dim
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
|
||||
)
|
||||
self.v_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
|
||||
)
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return (
|
||||
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
# scale post matmul
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(attn_weights.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class SiglipMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size
|
||||
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size
|
||||
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SiglipEncoderLayer(nn.Module):
|
||||
def __init__(self, prefix, config: SiglipConfig, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = SiglipAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.layer_norm2 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
||||
attention_mask (`torch.FloatTensor`):
|
||||
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if output_attentions:
|
||||
return hidden_states, attn_weights
|
||||
return hidden_states, None
|
||||
|
||||
|
||||
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
"""Multihead Attention Pooling."""
|
||||
|
||||
def __init__(self, prefix, config: SiglipVisionConfig, weights):
|
||||
super().__init__()
|
||||
|
||||
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
||||
self.attention = torch.nn.MultiheadAttention(
|
||||
config.hidden_size, config.num_attention_heads, batch_first=True
|
||||
)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(prefix, config, weights)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
batch_size = hidden_state.shape[0]
|
||||
probe = self.probe.repeat(batch_size, 1, 1)
|
||||
|
||||
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
||||
|
||||
residual = hidden_state
|
||||
hidden_state = self.layernorm(hidden_state)
|
||||
hidden_state = residual + self.mlp(hidden_state)
|
||||
|
||||
return hidden_state[:, 0]
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
def _trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l = norm_cdf((a - mean) / std)
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.0))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tensor.clamp_(min=a, max=b)
|
||||
|
||||
|
||||
def trunc_normal_tf_(
|
||||
tensor: torch.Tensor,
|
||||
mean: float = 0.0,
|
||||
std: float = 1.0,
|
||||
a: float = -2.0,
|
||||
b: float = 2.0,
|
||||
) -> torch.Tensor:
|
||||
"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \\leq \\text{mean} \\leq b`.
|
||||
|
||||
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
||||
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
||||
and the result is subsquently scaled and shifted by the mean and std args.
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mean: the mean of the normal distribution
|
||||
std: the standard deviation of the normal distribution
|
||||
a: the minimum cutoff value
|
||||
b: the maximum cutoff value
|
||||
"""
|
||||
with torch.no_grad():
|
||||
_trunc_normal_(tensor, 0, 1.0, a, b)
|
||||
tensor.mul_(std).add_(mean)
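
# --- Editor's illustrative sketch (not part of the diff) ---
# How the 'tf' variant behaves: the bounds are applied at mean=0, std=1 and the
# result is then scaled and shifted, so every value lands in
# [mean - 2*std, mean + 2*std]. Assumed example values:
import torch

w = torch.empty(4, 4)
trunc_normal_tf_(w, mean=0.5, std=0.02)  # sample in [-2, 2] at std=1, then scale/shift
assert w.min() >= 0.5 - 2 * 0.02 - 1e-6  # roughly 0.46
assert w.max() <= 0.5 + 2 * 0.02 + 1e-6  # roughly 0.54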


from torch.nn.init import _calculate_fan_in_and_fan_out


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
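
# --- Editor's illustrative sketch (not part of the diff) ---
# Sanity check of the fan-in scaling rule: for a Linear weight of shape (1152, 768),
# fan_in is 768, so lecun_normal_ targets an effective std of sqrt(1/768) ~= 0.036
# (the 0.8796... constant compensates for the truncation). Assumed example shape:
import math
import torch

w = torch.empty(1152, 768)
lecun_normal_(w)
print(w.std().item(), math.sqrt(1.0 / 768))  # the two values should be close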


from transformers import PreTrainedModel


class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SiglipVisionEmbeddings):
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(self, prefix, config: SiglipConfig, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                SiglipEncoderLayer(
                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[torch.Tensor] = None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
        """

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states, _ = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )

        return hidden_states


class SiglipVisionTransformer(nn.Module):
    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = SiglipEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Returns:

        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)

        # NOTE: up until this point, the logits are exactly
        # the same as the transformers code. The values evaluate
        # slightly differently in our encoder layer.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
        )
        last_hidden_state = encoder_outputs
        post_last_hidden_state = self.post_layernorm(last_hidden_state)

        return BaseModelOutputWithPooling(
            last_hidden_state=post_last_hidden_state,
            # pooler_output=pooled_output,
            # hidden_states=encoder_outputs,
        )
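
# --- Editor's illustrative sketch (not part of the diff) ---
# How the vision tower is typically driven. `vision_config`, `weights` and
# `pixel_values` are assumptions standing in for the config and safetensors
# Weights objects built by the surrounding TGI loading code:
vision_tower = SiglipVisionTransformer(
    prefix="vision_tower.vision_model", config=vision_config, weights=weights
)
outputs = vision_tower(pixel_values=pixel_values)  # pixel_values: (batch, 3, H, W)
image_features = outputs.last_hidden_state         # (batch, num_patches, hidden_size)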

@ -11,6 +11,18 @@ def load_text_model(prefix, config, weights, name=None):
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)
    elif config.model_type == "gemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
    elif config.model_type == "paligemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights)
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")
@ -24,5 +36,13 @@ def load_vision_model(prefix, config, weights):
        return CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )
    if config.model_type == "siglip_vision_model":
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        return SiglipVisionTransformer(
            prefix=f"vision_tower.vision_model", config=config, weights=weights
        )
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")
@ -135,6 +135,17 @@ class FlashCausalLMBatch(Batch):
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        position_ids = []
        speculative_ids = []
        cu_seqlen_prefill = [0]
@ -209,6 +220,7 @@ class FlashCausalLMBatch(Batch):
            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
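
# --- Editor's worked example (not part of the diff) ---
# The block arithmetic above, with assumed numbers and an assumed paged-attention
# BLOCK_SIZE of 16:
import math

BLOCK_SIZE = 16
input_length, max_new_tokens, speculative_length = 15, 20, 2
total_tokens = input_length + max_new_tokens - 1 + speculative_length  # 36
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)                   # 3
print(total_tokens, needed_blocks)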
@ -760,7 +772,7 @@ class FlashCausalLM(Model):
    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        try:
            cache_manager = set_cache_manager(
                batch.blocks,
@ -3,12 +3,11 @@ import torch.distributed

from opentelemetry import trace
from typing import Optional
from transformers.models.gemma import GemmaTokenizerFast
from transformers import AutoConfig, AutoTokenizer

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    FlashGemmaForCausalLM,
    GemmaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
@ -38,17 +37,15 @@ class FlashGemma(FlashCausalLM):
        else:
            raise NotImplementedError("FlashGemma is only available on GPU")

        tokenizer = GemmaTokenizerFast.from_pretrained(
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = GemmaConfig.from_pretrained(
        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
@ -61,7 +58,9 @@ class FlashGemma(FlashCausalLM):
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashGemmaForCausalLM(config, weights)
        # TODO hardcoded
        prefix = "language_model"
        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)

        torch.distributed.barrier(group=self.process_group)
        super(FlashGemma, self).__init__(
78
server/text_generation_server/models/flash_gpt2.py
Normal file
@ -0,0 +1,78 @@
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
    FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)

from text_generation_server.utils.import_utils import SYSTEM


class FlashGPT2(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "xpu":
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashGPT2 is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = FlashGPT2ForCausalLM(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(FlashGPT2, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

123
server/text_generation_server/models/pali_gemma.py
Normal file
@ -0,0 +1,123 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
    load_data_uri,
    split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += "<bos>" + chunk["content"] + "\n"
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
                    # on the Rust layer.
                    # This avoids making n queries per TP
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    # TODO do_convert_RGB should be on by default ?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs
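
# --- Editor's illustrative sketch (not part of the diff) ---
# What `full_text` looks like for one image chunk followed by one text chunk,
# assuming the image is replaced by config.text_config.num_image_tokens == 256
# placeholders (see the paligemma branch of image_text_replacement further below):
num_image_tokens = 256  # assumed value taken from the model config
full_text = "<image>" * num_image_tokens + "<bos>" + "Describe the picture." + "\n"
# The string is tokenized with add_special_tokens=False because <image> and <bos>
# are already spelled out explicitly in the prompt.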


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
@ -15,6 +15,7 @@ from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    FlashMistralBatch,
)
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)
@ -80,6 +81,9 @@ def image_text_replacement(image_input, config, image_id) -> str:

        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
        return "<image>" * num_features

    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")

@ -193,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=not config.model_type == "paligemma",
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
@ -14,7 +14,10 @@ from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLMBatch,
)
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@ -98,6 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
            PaliGemmaBatch,
        }:  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
@ -122,6 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
            PaliGemmaBatch,
        }:  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
@ -131,6 +131,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
@ -149,7 +150,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
            0.0,
            softmax_scale,
            False,
            True,
            causal,
            window_size_left,
            0,
            False,