Merge branch 'huggingface:main' into main

Commit 3116fb5113 by Sabidao, 2024-04-21 17:53:17 +03:00, committed via GitHub.
Signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194).
77 changed files with 65442 additions and 2243 deletions


@@ -13,7 +13,10 @@ jobs:
 - name: Install Launcher
 id: install-launcher
-run: cargo install --git https://github.com/${{ github.repository }} --branch ${{ github.head_ref }} text-generation-launcher
+env:
+REF: ${{ github.head_ref }}
+REPO: ${{ github.repository }}
+run: cargo install --git "https://github.com/$REPO" --branch "$REF" text-generation-launcher
 - name: Check launcher Docs are up-to-date
 run: |

Cargo.lock (generated): 913 changes. File diff suppressed because it is too large.


@@ -9,13 +9,19 @@ members = [
 resolver = "2"
 [workspace.package]
-version = "1.4.5"
+version = "2.0.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
+[workspace.dependencies]
+tokenizers = { version = "0.19.1", features = ["http"] }
+hf-hub = { version = "0.3.1", features = ["tokio"] }
 [profile.release]
 debug = 1
 incremental = true
-lto = "off"
+lto = "fat"
+opt-level = 3
+codegen-units = 1
 panic = "abort"


@@ -85,7 +85,7 @@ FROM pytorch-install as kernel-builder
 ARG MAX_JOBS=8
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-ninja-build \
+ninja-build cmake \
 && rm -rf /var/lib/apt/lists/*
 # Build Flash Attention CUDA kernels
@@ -160,11 +160,6 @@ WORKDIR /usr/src
 COPY server/Makefile-selective-scan Makefile
 RUN make build-all
-# Build megablocks
-FROM kernel-builder as megablocks-builder
-RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
 # Text Generation Inference base image
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
@@ -186,8 +181,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 curl \
 && rm -rf /var/lib/apt/lists/*
-# Copy conda with PyTorch and Megablocks installed
-COPY --from=megablocks-builder /opt/conda /opt/conda
+# Copy conda with PyTorch installed
+COPY --from=pytorch-install /opt/conda /opt/conda
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@@ -215,7 +210,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
-# Install flash-attention dependencies
+# Install vllm/flash-attention dependencies
 RUN pip install einops --no-cache-dir
 # Install server
@@ -250,5 +245,7 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base
-ENTRYPOINT ["text-generation-launcher"]
+COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
+ENTRYPOINT ["/tgi-entrypoint.sh"]
 CMD ["--json-output"]


@@ -76,7 +76,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
 model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
 ```
 And then you can make requests like
@@ -90,7 +90,7 @@ curl 127.0.0.1:8080/generate_stream \
 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above.
 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -120,7 +120,7 @@ model=meta-llama/Llama-2-7b-chat-hf
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
+docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
 ```
 ### A note on Shared Memory (shm)
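The README hunk above references the `curl 127.0.0.1:8080/generate_stream` example. For readers following along in Python, the sketch below is a rough equivalent against the container started by the `docker run` command; it is illustrative only: the prompt, token budget, and port mapping are assumptions, and the event fields follow TGI's documented streaming response shape.

```python
# Minimal sketch (not part of this diff): stream tokens from a local TGI
# container via /generate_stream. Prompt, max_new_tokens and the 8080 port
# mapping are assumptions taken from the README example above.
import json

import requests

with requests.post(
    "http://127.0.0.1:8080/generate_stream",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # Server-sent events are prefixed with "data:"; skip keep-alive blanks.
        if line.startswith(b"data:"):
            event = json.loads(line[len(b"data:"):])
            print(event["token"]["text"], end="", flush=True)
print()
```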


@@ -23,9 +23,9 @@ serde_json = "1.0"
 tabled = "0.14.0"
 text-generation-client = { path = "../router/client" }
 thiserror = "1.0.48"
-tokenizers = { version = "0.14.0", features = ["http"] }
+tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
 tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
-hf-hub = "0.3.1"
+hf-hub = { workspace = true }


@@ -9,6 +9,11 @@ def flan_t5_xxl():
 return "google/flan-t5-xxl"
+@pytest.fixture
+def llama_7b():
+return "meta-llama/Llama-2-7b-chat-hf"
 @pytest.fixture
 def fake_model():
 return "fake/model"
@@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl):
 return f"{base_url}/{flan_t5_xxl}"
+@pytest.fixture
+def llama_7b_url(base_url, llama_7b):
+return f"{base_url}/{llama_7b}"
 @pytest.fixture
 def fake_url(base_url, fake_model):
 return f"{base_url}/{fake_model}"


@@ -5,24 +5,24 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, InputToken
-def test_generate(flan_t5_xxl_url, hf_headers):
-client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate(llama_7b_url, hf_headers):
+client = Client(llama_7b_url, hf_headers)
 response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
-assert response.generated_text == ""
+assert response.generated_text == "_"
 assert response.details.finish_reason == FinishReason.Length
 assert response.details.generated_tokens == 1
 assert response.details.seed is None
-assert len(response.details.prefill) == 1
-assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+assert len(response.details.prefill) == 2
+assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
 assert len(response.details.tokens) == 1
-assert response.details.tokens[0].id == 3
-assert response.details.tokens[0].text == " "
+assert response.details.tokens[0].id == 29918
+assert response.details.tokens[0].text == "_"
 assert not response.details.tokens[0].special
-def test_generate_best_of(flan_t5_xxl_url, hf_headers):
-client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_best_of(llama_7b_url, hf_headers):
+client = Client(llama_7b_url, hf_headers)
 response = client.generate(
 "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
 )
@@ -39,14 +39,14 @@ def test_generate_not_found(fake_url, hf_headers):
 client.generate("test")
-def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
-client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_validation_error(llama_7b_url, hf_headers):
+client = Client(llama_7b_url, hf_headers)
 with pytest.raises(ValidationError):
 client.generate("test", max_new_tokens=10_000)
-def test_generate_stream(flan_t5_xxl_url, hf_headers):
-client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_stream(llama_7b_url, hf_headers):
+client = Client(llama_7b_url, hf_headers)
 responses = [
 response for response in client.generate_stream("test", max_new_tokens=1)
 ]
@@ -54,7 +54,7 @@ def test_generate_stream(flan_t5_xxl_url, hf_headers):
 assert len(responses) == 1
 response = responses[0]
-assert response.generated_text == ""
+assert response.generated_text == "_"
 assert response.details.finish_reason == FinishReason.Length
 assert response.details.generated_tokens == 1
 assert response.details.seed is None
@@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
 list(client.generate_stream("test"))
-def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
-client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_stream_validation_error(llama_7b_url, hf_headers):
+client = Client(llama_7b_url, hf_headers)
 with pytest.raises(ValidationError):
 list(client.generate_stream("test", max_new_tokens=10_000))
 @pytest.mark.asyncio
-async def test_generate_async(flan_t5_xxl_url, hf_headers):
-client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async(llama_7b_url, hf_headers):
+client = AsyncClient(llama_7b_url, hf_headers)
 response = await client.generate(
 "test", max_new_tokens=1, decoder_input_details=True
 )
-assert response.generated_text == ""
+assert response.generated_text == "_"
 assert response.details.finish_reason == FinishReason.Length
 assert response.details.generated_tokens == 1
 assert response.details.seed is None
-assert len(response.details.prefill) == 1
-assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+assert len(response.details.prefill) == 2
+assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
+assert response.details.prefill[1] == InputToken(
+id=1243, text="test", logprob=-10.96875
+)
 assert len(response.details.tokens) == 1
-assert response.details.tokens[0].id == 3
-assert response.details.tokens[0].text == " "
+assert response.details.tokens[0].id == 29918
+assert response.details.tokens[0].text == "_"
 assert not response.details.tokens[0].special
 @pytest.mark.asyncio
-async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
-client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async_best_of(llama_7b_url, hf_headers):
+client = AsyncClient(llama_7b_url, hf_headers)
 response = await client.generate(
 "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
 )
@@ -112,15 +115,15 @@ async def test_generate_async_not_found(fake_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
-client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async_validation_error(llama_7b_url, hf_headers):
+client = AsyncClient(llama_7b_url, hf_headers)
 with pytest.raises(ValidationError):
 await client.generate("test", max_new_tokens=10_000)
 @pytest.mark.asyncio
-async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
-client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_stream_async(llama_7b_url, hf_headers):
+client = AsyncClient(llama_7b_url, hf_headers)
 responses = [
 response async for response in client.generate_stream("test", max_new_tokens=1)
 ]
@@ -128,7 +131,7 @@ async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
 assert len(responses) == 1
 response = responses[0]
-assert response.generated_text == ""
+assert response.generated_text == "_"
 assert response.details.finish_reason == FinishReason.Length
 assert response.details.generated_tokens == 1
 assert response.details.seed is None
@@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
-client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
+client = AsyncClient(llama_7b_url, hf_headers)
 with pytest.raises(ValidationError):
 async for _ in client.generate_stream("test", max_new_tokens=10_000):
 pass


@@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
 usage: Optional[Any] = None
+class CompletionComplete(BaseModel):
+# Index of the chat completion
+index: int
+# Message associated with the chat completion
+text: str
+# Log probabilities for the chat completion
+logprobs: Optional[Any]
+# Reason for completion
+finish_reason: str
 class Function(BaseModel):
 name: Optional[str]
 arguments: str
@@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
 usage: Any
+class Completion(BaseModel):
+# Completion details
+id: str
+object: str
+created: int
+model: str
+system_fingerprint: str
+choices: List[CompletionComplete]
 class ChatRequest(BaseModel):
 # Model identifier
 model: str
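The two models added above mirror the OpenAI-style `/v1/completions` payload. As a quick illustration (not part of the diff), the sketch below re-declares models shaped like them and parses a hand-written sample payload; only the field names come from the classes above, the sample values are invented.

```python
# Illustration only: parse a /v1/completions-style payload with models shaped
# like the ones added above. The payload values are made up.
from typing import Any, List, Optional

from pydantic import BaseModel


class CompletionComplete(BaseModel):
    index: int               # index of the completion choice
    text: str                # generated text for this choice
    logprobs: Optional[Any]  # log probabilities, when requested
    finish_reason: str       # why generation stopped


class Completion(BaseModel):
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[CompletionComplete]


sample = {
    "id": "",
    "object": "text_completion",
    "created": 0,
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "system_fingerprint": "2.0.1-native",
    "choices": [
        {"index": 0, "text": " example", "logprobs": None, "finish_reason": "length"}
    ],
}

completion = Completion(**sample)
print(completion.choices[0].finish_reason)
```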


@@ -10,7 +10,7 @@
 "name": "Apache 2.0",
 "url": "https://www.apache.org/licenses/LICENSE-2.0"
 },
-"version": "1.4.5"
+"version": "2.0.1"
 },
 "paths": {
 "/": {
@@ -408,9 +408,14 @@
 },
 "responses": {
 "200": {
-"description": "Generated Text",
+"description": "Generated Chat Completion",
 "content": {
 "application/json": {
+"schema": {
+"$ref": "#/components/schemas/ChatCompletion"
+}
+},
+"text/event-stream": {
 "schema": {
 "$ref": "#/components/schemas/ChatCompletionChunk"
 }
@@ -492,11 +497,16 @@
 },
 "responses": {
 "200": {
-"description": "Generated Text",
+"description": "Generated Chat Completion",
 "content": {
 "application/json": {
 "schema": {
-"$ref": "#/components/schemas/ChatCompletionChunk"
+"$ref": "#/components/schemas/Completion"
+}
+},
+"text/event-stream": {
+"schema": {
+"$ref": "#/components/schemas/CompletionCompleteChunk"
 }
 }
 }
@@ -930,7 +940,7 @@
 "tool_prompt": {
 "type": "string",
 "description": "A prompt to be appended before the tools",
-"example": "\"Based on the conversation, please choose the most appropriate tool to use: \"",
+"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
 "nullable": true
 },
 "tools": {
@@ -1071,7 +1081,10 @@
 "example": "mistralai/Mistral-7B-Instruct-v0.2"
 },
 "prompt": {
-"type": "string",
+"type": "array",
+"items": {
+"type": "string"
+},
 "description": "The prompt to generate completions for.",
 "example": "What is Deep Learning?"
 },
@@ -1234,17 +1247,17 @@
 "type": "object",
 "required": [
 "name",
-"parameters"
+"arguments"
 ],
 "properties": {
+"arguments": {},
 "description": {
 "type": "string",
 "nullable": true
 },
 "name": {
 "type": "string"
-},
-"parameters": {}
+}
 }
 },
 "GenerateParameters": {
@@ -1260,7 +1273,7 @@
 },
 "decoder_input_details": {
 "type": "boolean",
-"default": "true"
+"default": "false"
 },
 "details": {
 "type": "boolean",
@@ -1285,6 +1298,7 @@
 "$ref": "#/components/schemas/GrammarType"
 }
 ],
+"default": "null",
 "nullable": true
 },
 "max_new_tokens": {
@@ -1478,6 +1492,7 @@
 "max_batch_total_tokens",
 "max_waiting_tokens",
 "validation_workers",
+"max_client_batch_size",
 "version"
 ],
 "properties": {
@@ -1503,6 +1518,11 @@
 "example": "2",
 "minimum": 0
 },
+"max_client_batch_size": {
+"type": "integer",
+"example": "32",
+"minimum": 0
+},
 "max_concurrent_requests": {
 "type": "integer",
 "description": "Router Parameters",


@@ -60,12 +60,13 @@ Options:
 [env: QUANTIZE=]
 Possible values:
-- awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
-- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
-- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
+- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
+- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
+- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
+- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
 ```
 ## SPECULATE
@@ -128,23 +129,29 @@ Options:
 [env: MAX_TOP_N_TOKENS=]
 [default: 5]
+```
+## MAX_INPUT_TOKENS
+```shell
+--max-input-tokens <MAX_INPUT_TOKENS>
+This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
+[env: MAX_INPUT_TOKENS=]
 ```
 ## MAX_INPUT_LENGTH
 ```shell
 --max-input-length <MAX_INPUT_LENGTH>
-This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle
+Legacy version of [`Args::max_input_tokens`]
 [env: MAX_INPUT_LENGTH=]
-[default: 1024]
 ```
 ## MAX_TOTAL_TOKENS
 ```shell
 --max-total-tokens <MAX_TOTAL_TOKENS>
-This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be
+This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
 [env: MAX_TOTAL_TOKENS=]
-[default: 2048]
 ```
 ## WAITING_SERVED_RATIO
@@ -161,10 +168,9 @@ Options:
 ## MAX_BATCH_PREFILL_TOKENS
 ```shell
 --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
-Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
+Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room
 [env: MAX_BATCH_PREFILL_TOKENS=]
-[default: 4096]
 ```
 ## MAX_BATCH_TOTAL_TOKENS
@@ -209,10 +215,9 @@ Options:
 ## CUDA_GRAPHS
 ```shell
 --cuda-graphs <CUDA_GRAPHS>
-Specify the batch sizes to compute cuda graphs for. Use "0" to disable
+Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Default = "1,2,4,8,16,32"
 [env: CUDA_GRAPHS=]
-[default: 1,2,4,8,16,32,64,96,128]
 ```
 ## HOSTNAME
@@ -393,6 +398,15 @@ Options:
 -e, --env
 Display a lot of information about your runtime environment
+```
+## MAX_CLIENT_BATCH_SIZE
+```shell
+--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
+Control the maximum number of inputs that a client can send in a single request
+[env: MAX_CLIENT_BATCH_SIZE=]
+[default: 4]
 ```
 ## HELP
 ```shell


@@ -74,7 +74,7 @@ curl localhost:3000/generate \
 A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
-> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compliation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
+> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
 ### Constrain with Pydantic
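For context on the note above, a grammar-constrained request against `/generate` looks roughly like the sketch below. It assumes the `parameters.grammar` payload shape used by this guidance document (the hunk header quotes its `curl localhost:3000/generate` example); the JSON schema, prompt, and port are illustrative.

```python
# Rough sketch of a grammar-constrained /generate call; schema, prompt and
# port are assumptions, the grammar payload shape follows the guidance doc.
import requests

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "Give me a character named David who is 30 years old.",
        "parameters": {
            "max_new_tokens": 64,
            # The first request pays the grammar compilation cost; later
            # requests reuse the cached grammar.
            "grammar": {"type": "json", "value": schema},
        },
    },
)
print(resp.json()["generated_text"])
```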


@@ -22,6 +22,8 @@ The following models are optimized and can be served with TGI, which uses custom
 - [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
 - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 - [Phi](https://huggingface.co/microsoft/phi-2)
+- [Idefics](HuggingFaceM4/idefics-9b-instruct) (Multimodal)
+- [Llava-next](llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal)
 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:


@@ -9,6 +9,7 @@ import json
 import math
 import time
 import random
+import re
 from docker.errors import NotFound
 from typing import Optional, List, Dict
@@ -26,6 +27,7 @@ from text_generation.types import (
 ChatComplete,
 ChatCompletionChunk,
 ChatCompletionComplete,
+Completion,
 )
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -69,17 +71,22 @@ class ResponseComparator(JSONSnapshotExtension):
 data = json.loads(data)
 if isinstance(data, Dict) and "choices" in data:
 choices = data["choices"]
-if (
-isinstance(choices, List)
-and len(choices) >= 1
-and "delta" in choices[0]
-):
+if isinstance(choices, List) and len(choices) >= 1:
+if "delta" in choices[0]:
 return ChatCompletionChunk(**data)
+if "text" in choices[0]:
+return Completion(**data)
 return ChatComplete(**data)
 if isinstance(data, Dict):
 return Response(**data)
 if isinstance(data, List):
+if (
+len(data) > 0
+and "object" in data[0]
+and data[0]["object"] == "text_completion"
+):
+return [Completion(**d) for d in data]
 return [Response(**d) for d in data]
 raise NotImplementedError
@@ -161,6 +168,9 @@ class ResponseComparator(JSONSnapshotExtension):
 )
 )
+def eq_completion(response: Completion, other: Completion) -> bool:
+return response.choices[0].text == other.choices[0].text
 def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
 return (
 response.choices[0].message.content == other.choices[0].message.content
@@ -184,6 +194,11 @@ class ResponseComparator(JSONSnapshotExtension):
 if not isinstance(snapshot_data, List):
 snapshot_data = [snapshot_data]
+if isinstance(serialized_data[0], Completion):
+return len(snapshot_data) == len(serialized_data) and all(
+[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
+)
 if isinstance(serialized_data[0], ChatComplete):
 return len(snapshot_data) == len(serialized_data) and all(
 [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
@@ -277,6 +292,8 @@ def launcher(event_loop):
 disable_grammar_support: bool = False,
 dtype: Optional[str] = None,
 revision: Optional[str] = None,
+max_input_length: Optional[int] = None,
+max_total_tokens: Optional[int] = None,
 ):
 port = random.randint(8000, 10_000)
 master_port = random.randint(10_000, 20_000)
@@ -314,6 +331,12 @@ def launcher(event_loop):
 args.append(revision)
 if trust_remote_code:
 args.append("--trust-remote-code")
+if max_input_length:
+args.append("--max-input-length")
+args.append(str(max_input_length))
+if max_total_tokens:
+args.append("--max-total-tokens")
+args.append(str(max_total_tokens))
 env["LOG_LEVEL"] = "info,text_generation_router=debug"
@@ -347,6 +370,8 @@ def launcher(event_loop):
 disable_grammar_support: bool = False,
 dtype: Optional[str] = None,
 revision: Optional[str] = None,
+max_input_length: Optional[int] = None,
+max_total_tokens: Optional[int] = None,
 ):
 port = random.randint(8000, 10_000)
@@ -367,6 +392,12 @@ def launcher(event_loop):
 args.append(revision)
 if trust_remote_code:
 args.append("--trust-remote-code")
+if max_input_length:
+args.append("--max-input-length")
+args.append(str(max_input_length))
+if max_total_tokens:
+args.append("--max-total-tokens")
+args.append(str(max_total_tokens))
 client = docker.from_env()


@@ -13,11 +13,11 @@
 "usage": null
 }
 ],
-"created": 1710795556,
+"created": 1712874856,
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "1.4.5-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
 "completion_tokens": 100,
 "prompt_tokens": 60,


@@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 1,
"logprobs": null,
"text": " PR for more information?"
},
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
},
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": " severely flawed and often has a substandard"
},
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": "hd20220811-"
}
],
"created": 1713284455,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 8,
"total_tokens": 44
}
}


@@ -0,0 +1,602 @@
[
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "hd"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "aho"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " is"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "m"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Room"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " the"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " tired"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": ":"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " capital"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " of"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " She"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " scale"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " of"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " being"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
}
]


@@ -0,0 +1,20 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " PR for flake8"
}
],
"created": 1713284454,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 5,
"prompt_tokens": 6,
"total_tokens": 11
}
}


@@ -0,0 +1,65 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 6,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -10.5,
"text": "Test"
},
{
"id": 2159,
"logprob": -12.140625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.0654297,
"special": false,
"text": "\n"
},
{
"id": 1014,
"logprob": -2.7460938,
"special": false,
"text": "The"
},
{
"id": 6032,
"logprob": -1.359375,
"special": false,
"text": " purpose"
},
{
"id": 302,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 456,
"logprob": 0.0,
"special": false,
"text": " this"
},
{
"id": 1369,
"logprob": -0.40063477,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "Test request\nThe purpose of this test"
}


@@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.00756073,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.20117188,
"special": false,
"text": "\n"
},
{
"id": 16114,
"logprob": -1.2597656,
"special": false,
"text": "Once"
},
{
"id": 3714,
"logprob": -0.20825195,
"special": false,
"text": " upon"
},
{
"id": 264,
"logprob": -0.00178051,
"special": false,
"text": " a"
},
{
"id": 727,
"logprob": -0.011955261,
"special": false,
"text": " time"
},
{
"id": 28725,
"logprob": -0.17541504,
"special": false,
"text": ","
},
{
"id": 736,
"logprob": -0.91308594,
"special": false,
"text": " there"
},
{
"id": 403,
"logprob": -0.058410645,
"special": false,
"text": " was"
},
{
"id": 264,
"logprob": -0.009689331,
"special": false,
"text": " a"
}
],
"top_tokens": null
},
"generated_text": "\n\nOnce upon a time, there was a"
}


@@ -1,8 +1,8 @@
 {
 "details": {
 "best_of_sequences": null,
-"finish_reason": "eos_token",
-"generated_tokens": 9,
+"finish_reason": "length",
+"generated_tokens": 10,
 "prefill": [
 {
 "id": 0,
@@ -14,7 +14,7 @@
 "tokens": [
 {
 "id": 16017,
-"logprob": -0.30908203,
+"logprob": 0.0,
 "special": false,
 "text": " blue"
 },
@@ -26,39 +26,45 @@
 },
 {
 "id": 259,
-"logprob": -0.28271484,
+"logprob": -0.4716797,
 "special": false,
 "text": " "
 },
 {
-"id": 15484,
-"logprob": -1.7929688,
+"id": 261,
+"logprob": -0.044677734,
 "special": false,
-"text": "appear"
+"text": ","
 },
 {
-"id": 345,
-"logprob": -0.8935547,
+"id": 35622,
+"logprob": -0.79589844,
 "special": false,
-"text": "ed"
+"text": " cloud"
 },
 {
-"id": 281,
+"id": 263,
+"logprob": -1.2958984,
+"special": false,
+"text": "s"
+},
+{
+"id": 305,
 "logprob": 0.0,
 "special": false,
-"text": " in"
+"text": " and"
 },
 {
-"id": 287,
+"id": 35622,
+"logprob": -1.1630859,
+"special": false,
+"text": " cloud"
+},
+{
+"id": 263,
 "logprob": 0.0,
 "special": false,
-"text": " the"
-},
-{
-"id": 20495,
-"logprob": -0.32299805,
-"special": false,
-"text": " sky"
+"text": "s"
 },
 {
 "id": 1,
@@ -66,7 +72,8 @@
 "special": true,
 "text": "</s>"
 }
-]
+],
+"top_tokens": null
 },
-"generated_text": "Why is the sky blue?blue sky appeared in the sky"
+"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
 }


@@ -11,13 +11,12 @@
 "tool_calls": [
 {
 "function": {
-"description": null,
-"name": "tools",
-"parameters": {
+"arguments": {
 "format": "celsius",
-"location": "New York, NY",
-"num_days": 14
-}
+"location": "Brooklyn"
+},
+"description": null,
+"name": "get_current_weather"
 },
 "id": 0,
 "type": "function"
@@ -27,14 +26,14 @@
 "usage": null
 }
 ],
-"created": 1710795556,
+"created": 1712782670,
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "1.4.5-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
-"completion_tokens": 29,
-"prompt_tokens": 316,
-"total_tokens": 345
+"completion_tokens": 37,
+"prompt_tokens": 524,
+"total_tokens": 561
 }
 }


@@ -11,13 +11,12 @@
 "tool_calls": [
 {
 "function": {
-"description": null,
-"name": "tools",
-"parameters": {
+"arguments": {
 "format": "celsius",
-"location": "New York, NY",
-"num_days": 14
-}
+"location": "Brooklyn"
+},
+"description": null,
+"name": "get_current_weather"
 },
 "id": 0,
 "type": "function"
@@ -27,14 +26,14 @@
 "usage": null
 }
 ],
-"created": 1710795557,
+"created": 1712787937,
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "1.4.5-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
-"completion_tokens": 29,
-"prompt_tokens": 316,
-"total_tokens": 345
+"completion_tokens": 37,
+"prompt_tokens": 524,
+"total_tokens": 561
 }
 }


@@ -11,12 +11,12 @@
 "tool_calls": [
 {
 "function": {
-"description": null,
-"name": "tools",
-"parameters": {
+"arguments": {
 "format": "celsius",
 "location": "New York, NY"
-}
+},
+"description": null,
+"name": "get_current_weather"
 },
 "id": 0,
 "type": "function"
@@ -26,14 +26,14 @@
 "usage": null
 }
 ],
-"created": 1710795557,
+"created": 1712852394,
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "1.4.5-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
-"completion_tokens": 21,
-"prompt_tokens": 187,
-"total_tokens": 208
+"completion_tokens": 48,
+"prompt_tokens": 320,
+"total_tokens": 368
 }
 }


@@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": null,
"name": "notify_error"
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1712852597,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 496,
"total_tokens": 535
}
}


@@ -19,9 +19,9 @@
 "logprobs": null
 }
 ],
-"created": 1710795499,
+"created": 1712788218,
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "1.4.5-native"
+"system_fingerprint": "2.0.1-native"
 }


@@ -0,0 +1,42 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_chat_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_chat(flash_llama_chat_handle):
await flash_llama_chat_handle.health(300)
return flash_llama_chat_handle.client
@pytest.mark.private
async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
response = await flash_llama_chat.chat(
max_tokens=100,
seed=1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert (
response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot


@@ -0,0 +1,109 @@
import pytest
import requests
import json
from aiohttp import ClientSession
from text_generation.types import (
Completion,
)
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
await flash_llama_completion_handle.health(300)
return flash_llama_completion_handle.client
# NOTE: since `v1/completions` is a deprecated inferface/endpoint we do not provide a convience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": "Say this is a test",
"max_tokens": 5,
"seed": 0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 1
assert response == response_snapshot
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": ["Say", "this", "is", "a"],
"max_tokens": 10,
"seed": 0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 4
all_indexes = [choice["index"] for choice in response["choices"]]
all_indexes.sort()
assert all_indexes == [0, 1, 2, 3]
assert response == response_snapshot
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):
request = {
"model": "tgi",
"prompt": [
"What color is the sky?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"stream": True,
}
url = f"{flash_llama_completion.base_url}/v1/completions"
chunks = []
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(Completion(**c))
assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4
assert response.status == 200
assert chunks == response_snapshot


@ -33,6 +33,9 @@ async def test_idefics(idefics, response_snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert (
response.generated_text == " \nAssistant: A rooster stands"
), f"{repr(response.generated_text)}"
assert response == response_snapshot assert response == response_snapshot
@ -48,6 +51,9 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses] generated_texts = [r.generated_text for r in responses]
assert (
generated_texts[0] == " \nAssistant: A rooster stands"
), f"{response.generated_text}"
assert len(generated_texts) == 4 assert len(generated_texts) == 4
assert generated_texts, all( assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts] [text == generated_texts[0] for text in generated_texts]


@ -0,0 +1,84 @@
import pytest
import base64
# TODO: fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
def flash_llava_next_handle(launcher):
with launcher(
"llava-hf/llava-v1.6-mistral-7b-hf",
num_shard=4,
max_input_length=4000,
max_total_tokens=4096,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llava_next(flash_llava_next_handle):
await flash_llava_next_handle.health(300)
return flash_llava_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
chicken = get_chicken()
response = await flash_llava_next.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
)
assert (
response.generated_text == "\n\nOnce upon a time, there was a"
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
response = await flash_llava_next.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 6
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(
flash_llava_next, generate_load, response_snapshot
):
chicken = get_chicken()
responses = await generate_load(
flash_llava_next,
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert generated_texts[0] == "\n\nOnce upon a time, there was a"
assert len(generated_texts) == 4
assert all([r.generated_text == generated_texts[0] for r in responses])
assert responses == response_snapshot


@ -45,7 +45,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
seed=0, seed=0,
) )
assert response.details.generated_tokens == 9 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == response_snapshot


@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def t5_sharded_handle(launcher): def t5_sharded_handle(launcher):
with launcher("google/flan-t5-xxl", num_shard=2) as handle: with launcher("google/flan-t5-xxl", num_shard=4) as handle:
yield handle yield handle


@ -71,34 +71,7 @@ tools = [
] ]
@pytest.mark.asyncio @pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert (
response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0, "id": 0,
"type": "function", "type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip @pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_auto( async def test_flash_llama_grammar_tools_auto(
@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0, "id": 0,
"type": "function", "type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip @pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_choice( async def test_flash_llama_grammar_tools_choice(
@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice(
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "tools", "name": "get_current_weather",
"parameters": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "New York, NY"},
}, },
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip @pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_stream( async def test_flash_llama_grammar_tools_stream(
@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses: async for response in responses:
count += 1 count += 1
assert count == 20 assert count == 38
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=8,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=False,
)
assert responses.choices[0].message.content == None
assert responses.choices[0].message.tool_calls == [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]
assert responses == response_snapshot


@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-integration-tests" name = "text-generation-integration-tests"
version = "1.4.5" version = "2.0.1"
description = "Text Generation Inference integration tests" description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"] authors = ["Nicolas Patry <nicolas@huggingface.co>"]


@ -9,7 +9,9 @@ homepage.workspace = true
[dependencies] [dependencies]
clap = { version = "4.4.5", features = ["derive", "env"] } clap = { version = "4.4.5", features = ["derive", "env"] }
ctrlc = { version = "3.4.1", features = ["termination"] } ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] } nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
serde = { version = "1.0.188", features = ["derive"] } serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107" serde_json = "1.0.107"
tracing = "0.1.37" tracing = "0.1.37"


@ -1,4 +1,5 @@
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use nix::sys::signal::{self, Signal}; use nix::sys::signal::{self, Signal};
use nix::unistd::Pid; use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
@ -19,17 +20,23 @@ use tracing_subscriber::EnvFilter;
mod env_runtime; mod env_runtime;
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
max_seq_len: Option<usize>,
}
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization { enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model: /// 4 bit quantization. Requires a specific AWQ quantized model:
/// https://hf.co/models?search=awq. /// <https://hf.co/models?search=awq>.
/// Should replace GPTQ models wherever possible because of the better latency /// Should replace GPTQ models wherever possible because of the better latency
Awq, Awq,
/// 8 bit quantization, doesn't require specific model. /// 8 bit quantization, doesn't require specific model.
/// Should be a drop-in replacement to bitsandbytes with much better performance. /// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from https://github.com/NetEase-FuXi/EETQ.git /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq, Eetq,
/// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not. /// triton kernel (wider support) when it's not.
/// AWQ has faster kernels. /// AWQ has faster kernels.
@ -47,6 +54,11 @@ enum Quantization {
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
/// perplexity performance for you model /// perplexity performance for you model
BitsandbytesFP4, BitsandbytesFP4,
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
/// This dtype has native ops and should be the fastest if available.
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
/// multiplication limitations.
Fp8,
} }
impl std::fmt::Display for Quantization { impl std::fmt::Display for Quantization {
@ -73,6 +85,9 @@ impl std::fmt::Display for Quantization {
Quantization::Eetq => { Quantization::Eetq => {
write!(f, "eetq") write!(f, "eetq")
} }
Quantization::Fp8 => {
write!(f, "fp8")
}
} }
} }
} }
@ -206,8 +221,13 @@ struct Args {
/// for users. The larger this value, the longer prompt users can send which /// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load. /// can impact the overall memory required to handle the load.
/// Please note that some models have a finite range of sequence they can handle. /// Please note that some models have a finite range of sequence they can handle.
#[clap(default_value = "1024", long, env)] /// Default to min(max_position_embeddings - 1, 4095)
max_input_length: usize, #[clap(long, env)]
max_input_tokens: Option<usize>,
/// Legacy version of [`Args::max_input_tokens`].
#[clap(long, env)]
max_input_length: Option<usize>,
/// This is the most important value to set as it defines the "memory budget" /// This is the most important value to set as it defines the "memory budget"
/// of running clients requests. /// of running clients requests.
@ -217,8 +237,9 @@ struct Args {
/// `1511` max_new_tokens. /// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM /// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be. /// and the less effective batching can be.
#[clap(default_value = "2048", long, env)] /// Default to min(max_position_embeddings, 4096)
max_total_tokens: usize, #[clap(long, env)]
max_total_tokens: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where /// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting /// you want to start considering pausing the running queries to include the waiting
@ -236,8 +257,9 @@ struct Args {
/// Limits the number of tokens for the prefill operation. /// Limits the number of tokens for the prefill operation.
/// Since this operation take the most memory and is compute bound, it is interesting /// Since this operation take the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent. /// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)] /// Default to `max_input_tokens + 50` to give a bit of room.
max_batch_prefill_tokens: u32, #[clap(long, env)]
max_batch_prefill_tokens: Option<u32>,
/// **IMPORTANT** This is one critical control to allow maximum usage /// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware. /// of the available hardware.
@ -286,13 +308,9 @@ struct Args {
/// Specify the batch sizes to compute cuda graphs for. /// Specify the batch sizes to compute cuda graphs for.
/// Use "0" to disable. /// Use "0" to disable.
#[clap( /// Default = "1,2,4,8,16,32"
long, #[clap(long, env, value_delimiter = ',')]
env, cuda_graphs: Option<Vec<usize>>,
value_delimiter = ',',
default_value = "1,2,4,8,16,32,64,96,128"
)]
cuda_graphs: Vec<usize>,
/// The IP address to listen on /// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)] #[clap(default_value = "0.0.0.0", long, env)]
@ -396,6 +414,10 @@ struct Args {
/// Display a lot of information about your runtime environment /// Display a lot of information about your runtime environment
#[clap(long, short, action)] #[clap(long, short, action)]
env: bool, env: bool,
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
} }
#[derive(Debug)] #[derive(Debug)]
@ -499,6 +521,9 @@ fn shard_manager(
// Copy current process env // Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Remove LOG_LEVEL if present
envs.retain(|(name, _)| name != "LOG_LEVEL");
// Torch Distributed Env vars // Torch Distributed Env vars
envs.push(("RANK".into(), rank.to_string().into())); envs.push(("RANK".into(), rank.to_string().into()));
envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -586,6 +611,7 @@ fn shard_manager(
tracing::info!("Starting shard"); tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server") let mut p = match Command::new("text-generation-server")
.args(shard_args) .args(shard_args)
.env_clear()
.envs(envs) .envs(envs)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
@ -796,6 +822,14 @@ enum LauncherError {
WebserverCannotStart, WebserverCannotStart,
} }
impl core::fmt::Display for LauncherError {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for LauncherError {}
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> { fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
// Enter download tracing span // Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let _span = tracing::span!(tracing::Level::INFO, "download").entered();
@ -824,6 +858,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Copy current process env // Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Remove LOG_LEVEL if present
envs.retain(|(name, _)| name != "LOG_LEVEL");
// Disable progress bar // Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
@ -858,6 +895,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
tracing::info!("Starting download process."); tracing::info!("Starting download process.");
let mut download_process = match Command::new("text-generation-server") let mut download_process = match Command::new("text-generation-server")
.args(download_args) .args(download_args)
.env_clear()
.envs(envs) .envs(envs)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
@ -928,6 +966,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
fn spawn_shards( fn spawn_shards(
num_shard: usize, num_shard: usize,
args: &Args, args: &Args,
cuda_graphs: Vec<usize>,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>, shutdown_sender: mpsc::Sender<()>,
@ -955,11 +994,7 @@ fn spawn_shards(
let disable_custom_kernels = args.disable_custom_kernels; let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma; let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta; let watermark_delta = args.watermark_delta;
let cuda_graphs: Vec<usize> = args let cuda_graphs_clone = cuda_graphs.clone();
.cuda_graphs
.iter()
.filter_map(|&c| if c > 0 { Some(c) } else { None })
.collect();
let cuda_memory_fraction = args.cuda_memory_fraction; let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling; let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor; let rope_factor = args.rope_factor;
@ -981,7 +1016,7 @@ fn spawn_shards(
disable_custom_kernels, disable_custom_kernels,
watermark_gamma, watermark_gamma,
watermark_delta, watermark_delta,
cuda_graphs, cuda_graphs_clone,
cuda_memory_fraction, cuda_memory_fraction,
rope_scaling, rope_scaling,
rope_factor, rope_factor,
@ -1037,6 +1072,9 @@ fn compute_type(num_shard: usize) -> Option<String> {
fn spawn_webserver( fn spawn_webserver(
num_shard: usize, num_shard: usize,
args: Args, args: Args,
max_input_tokens: usize,
max_total_tokens: usize,
max_batch_prefill_tokens: u32,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> { ) -> Result<Child, LauncherError> {
@ -1044,6 +1082,8 @@ fn spawn_webserver(
// Start webserver // Start webserver
tracing::info!("Starting Webserver"); tracing::info!("Starting Webserver");
let mut router_args = vec![ let mut router_args = vec![
"--max-client-batch-size".to_string(),
args.max_client_batch_size.to_string(),
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
args.max_concurrent_requests.to_string(), args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(), "--max-best-of".to_string(),
@ -1052,12 +1092,12 @@ fn spawn_webserver(
args.max_stop_sequences.to_string(), args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(), "--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(), args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(), "--max-input-tokens".to_string(),
args.max_input_length.to_string(), max_input_tokens.to_string(),
"--max-total-tokens".to_string(), "--max-total-tokens".to_string(),
args.max_total_tokens.to_string(), max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(), "--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(), max_batch_prefill_tokens.to_string(),
"--waiting-served-ratio".to_string(), "--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
@ -1209,16 +1249,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
} }
fn main() -> Result<(), LauncherError> { fn main() -> Result<(), LauncherError> {
match Command::new("ldconfig").spawn() {
Ok(_) => {}
Err(err) => {
tracing::warn!(
"Unable to refresh ldconfig cache. Skipping (useless in most cases). Details {:?}",
err
)
}
}
// Pattern match configuration // Pattern match configuration
let args: Args = Args::parse(); let args: Args = Args::parse();
@ -1245,18 +1275,128 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:?}", args); tracing::info!("{:?}", args);
// Validate args let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
if args.max_input_length >= args.max_total_tokens { let model_id = args.model_id.clone();
return Err(LauncherError::ArgumentValidation( let mut path = std::path::Path::new(&args.model_id).to_path_buf();
"`max_input_length` must be < `max_total_tokens`".to_string(), let filename = if !path.exists() {
)); // Assume it's a hub id
let api = Api::new()?;
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?;
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
max_default
} else {
max_position_embeddings
} }
if args.max_input_length as u32 > args.max_batch_prefill_tokens { }
return Err(LauncherError::ArgumentValidation(format!( _ => {
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}", return Err(Box::new(LauncherError::ArgumentValidation(
args.max_batch_prefill_tokens, args.max_input_length "no max defined".to_string(),
))); )));
} }
};
Ok(max_position_embeddings)
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
(Some(max_input_tokens), Some(max_input_length)) => {
return Err(LauncherError::ArgumentValidation(
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
)));
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => {
let value = max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}");
value
}
}
};
let max_total_tokens = {
match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens,
None => {
let value = max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}");
value
}
}
};
let max_batch_prefill_tokens = {
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
max_batch_size * max_input_tokens
} else {
// Adding some headroom in order to account for potential block_size alignment
// issues.
max_input_tokens + 50
} as u32;
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
}
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
#[allow(deprecated)]
(
None,
Some(
Quantization::Bitsandbytes
| Quantization::BitsandbytesNF4
| Quantization::BitsandbytesFP4,
),
) => {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
cuda_graphs
}
};
if args.validation_workers == 0 { if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation( return Err(LauncherError::ArgumentValidation(
@ -1276,16 +1416,16 @@ fn main() -> Result<(), LauncherError> {
} }
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens { if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens max_batch_prefill_tokens, max_batch_total_tokens
))); )));
} }
if args.max_total_tokens as u32 > *max_batch_total_tokens { if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens max_total_tokens, max_batch_total_tokens
))); )));
} }
} }
@ -1332,6 +1472,7 @@ fn main() -> Result<(), LauncherError> {
spawn_shards( spawn_shards(
num_shard, num_shard,
&args, &args,
cuda_graphs,
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,
shutdown_sender, shutdown_sender,
@ -1346,7 +1487,15 @@ fn main() -> Result<(), LauncherError> {
return Ok(()); return Ok(());
} }
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver) let mut webserver = spawn_webserver(
num_shard,
args,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
shutdown.clone(),
&shutdown_receiver,
)
.map_err(|err| { .map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver); shutdown_shards(shutdown.clone(), &shutdown_receiver);
err err

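The launcher hunks above replace the old hard-coded defaults with values derived from the model's `max_position_embeddings`. A minimal standalone sketch of that cascade, for reference only: `resolve_defaults` is a made-up helper, and the real code above additionally caps the positional limit at 4096 when reading config.json and folds `--max-batch-size` into the prefill default.

// Sketch of the default resolution described in the doc comments above (assumption: flags unset).
fn resolve_defaults(
    max_position_embeddings: usize,
    max_input_tokens: Option<usize>,
    max_total_tokens: Option<usize>,
    max_batch_prefill_tokens: Option<u32>,
) -> (usize, usize, u32) {
    // Default input budget: one token short of the positional limit.
    let max_input_tokens = max_input_tokens.unwrap_or(max_position_embeddings - 1);
    // Default total budget: the positional limit itself.
    let max_total_tokens = max_total_tokens.unwrap_or(max_position_embeddings);
    // Default prefill budget: input budget plus a small margin for block alignment.
    let max_batch_prefill_tokens =
        max_batch_prefill_tokens.unwrap_or((max_input_tokens + 50) as u32);
    (max_input_tokens, max_total_tokens, max_batch_prefill_tokens)
}

fn main() {
    // A 4096-token model with no flags set resolves to 4095 / 4096 / 4145,
    // which then has to pass the `max_input_tokens < max_total_tokens` and
    // `max_batch_prefill_tokens >= max_input_tokens` validations below.
    assert_eq!(resolve_defaults(4096, None, None, None), (4095, 4096, 4145));
}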

@ -21,7 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] } clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28" futures = "0.3.28"
hf-hub = { version = "0.3.0", features = ["tokio"] } hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1" metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] } metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@ -33,7 +33,7 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { version = "0.15.1", features = ["http"] } tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] } tower-http = { version = "0.4.4", features = ["cors"] }
@ -44,10 +44,12 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" } minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
futures-util = "0.3.30" futures-util = "0.3.30"
regex = "1.10.3" regex = "1.10.3"
once_cell = "1.19.0" once_cell = "1.19.0"
image = "0.25.1"
base64 = "0.22.0"
[build-dependencies] [build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }


@ -112,10 +112,15 @@ impl Client {
// Create requests // Create requests
while n_tokens < max_prefill_tokens { while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
requests.push(Request { requests.push(Request {
id: 0, id: 0,
// We truncate the input on the server side to be sure that it has the correct size // We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize), inputs,
truncate, truncate,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {

router/src/config.rs (new file, 158 lines)

@ -0,0 +1,158 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
text_config: TextConfig,
vision_config: VisionConfig,
image_grid_pinpoints: Vec<(usize, usize)>,
}
fn get_anyres_image_grid_shape(
height: usize,
width: usize,
grid_pinpoints: &[(usize, usize)],
patch_size: usize,
) -> (usize, usize) {
let (height, width) = select_best_resolution(height, width, grid_pinpoints);
(height / patch_size, width / patch_size)
}
/// Selects the best resolution from a list of possible resolutions based on the original size.
/// This is done by calculating the effective and wasted resolution for each possible resolution.
/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
fn select_best_resolution(
original_height: usize,
original_width: usize,
possible_resolutions: &[(usize, usize)],
) -> (usize, usize) {
let mut best_fit = None;
let mut max_effective_resolution = 0;
let mut min_wasted_resolution = f32::NEG_INFINITY;
for (height, width) in possible_resolutions {
let wscale = *width as f32 / original_width as f32;
let hscale = *height as f32 / original_height as f32;
// f32 partial ord.
let scale = if wscale > hscale { hscale } else { wscale };
let downscaled_width = (*width as f32 * scale) as usize;
let downscaled_height = (*height as f32 * scale) as usize;
let effective_resolution = std::cmp::min(
downscaled_width * downscaled_height,
original_width * original_height,
);
let wasted_resolution = (width * height) - effective_resolution;
if effective_resolution > max_effective_resolution
|| (effective_resolution == max_effective_resolution
&& (wasted_resolution as f32) < min_wasted_resolution)
{
max_effective_resolution = effective_resolution;
min_wasted_resolution = wasted_resolution as f32;
best_fit = Some((*height, *width));
}
}
best_fit.unwrap_or((original_height, original_width))
}
impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size;
let patch_size = self.vision_config.patch_size;
assert!(image_size % patch_size == 0);
let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width;
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
// The base patch covers the entire image
let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
image_size: usize,
patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
LlavaNext(LlavaNext),
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Ssm,
GptBigcode,
Santacoder,
Bloom,
Mpt,
GptNeox,
Phi,
#[serde(rename = "phi-msft")]
PhiMsft,
Llama,
Baichuan,
Gemma,
Cohere,
Drbx,
Falcon,
Mixtral,
Starcoder2,
Qwen2,
Opt,
T5,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextConfig {}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
image_size: usize,
patch_size: usize,
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_llava_next_features() {
let config = LlavaNext {
text_config: TextConfig {},
vision_config: VisionConfig {
image_size: 336,
patch_size: 14,
},
image_grid_pinpoints: vec![
(336, 672),
(672, 336),
(672, 672),
(1008, 336),
(336, 1008),
],
};
let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732);
let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320);
}
}

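As a cross-check of `get_number_of_features`, the 640x640 case asserted in the test above can be worked out by hand. This is an illustrative snippet, assuming the same 336/14 CLIP vision settings and the (672, 672) pinpoint that `select_best_resolution` picks for a square 640x640 image:

fn main() {
    let npatches: usize = 336 / 14; // 24 patches per side of the base image
    let (num_patch_height, num_patch_width): (usize, usize) = (672 / 336, 672 / 336); // 2x2 grid
    let (height, width): (usize, usize) = (640, 640);
    // Ceiling division, as in `get_number_of_features`
    let height_of_patch = (height * npatches + width - 1) / width; // 24
    let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width; // 2304
    let newline_features = height_of_patch * num_patch_width; // one newline token per patch row: 48
    let base_features = npatches * npatches; // the base 336x336 image patch: 576
    assert_eq!(unpadded_features + newline_features + base_features, 2928);
}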

@ -1,12 +1,15 @@
/// Batching and inference logic /// Batching and inference logic
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig, ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
Message, PrefillToken, Queue, Token, HubTokenizerConfig, Message, PrefillToken, Queue, Token,
}; };
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
@ -86,7 +89,18 @@ impl Infer {
let chat_template = tokenizer_config let chat_template = tokenizer_config
.chat_template .chat_template
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); .and_then(|t| match t {
ChatTemplateVersions::Single(template) => Some(template),
ChatTemplateVersions::Multiple(templates) => templates
.into_iter()
.find(|t| t.name == "default")
.map(|t| t.template),
})
.map(|t| {
// .strip() is not supported in minijinja
let t = t.replace(".strip()", " | trim");
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
});
// Inference limit with a semaphore // Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@ -174,11 +188,15 @@ impl Infer {
/// Apply the chat template to the chat request /// Apply the chat template to the chat request
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> { pub(crate) fn apply_chat_template(
&self,
messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
self.chat_template self.chat_template
.as_ref() .as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages) .apply(messages, grammar_with_prompt)
.map_err(|e| { .map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template"); metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}"); tracing::error!("{e}");
@ -311,6 +329,7 @@ struct ChatTemplate {
template: Template<'static, 'static>, template: Template<'static, 'static>,
bos_token: Option<String>, bos_token: Option<String>,
eos_token: Option<String>, eos_token: Option<String>,
use_default_tool_template: bool,
} }
impl ChatTemplate { impl ChatTemplate {
@ -318,6 +337,10 @@ impl ChatTemplate {
let mut env = Box::new(Environment::new()); let mut env = Box::new(Environment::new());
let template_str = template.into_boxed_str(); let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance. // leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env) let template = Box::leak(env)
.template_from_str(Box::leak(template_str)) .template_from_str(Box::leak(template_str))
@ -327,21 +350,159 @@ impl ChatTemplate {
template, template,
bos_token, bos_token,
eos_token, eos_token,
use_default_tool_template,
}
}
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content = Some(format!(
"{}\n---\n{}\n{}",
last_message.content.as_deref().unwrap_or_default(),
tool_prompt,
tools
));
}
} }
} }
fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
self.template self.template
.render(ChatTemplateInputs { .render(ChatTemplateInputs {
messages, messages,
bos_token: self.bos_token.as_deref(), bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(), eos_token: self.eos_token.as_deref(),
add_generation_prompt: true, add_generation_prompt: true,
tools: None,
tools_prompt: None,
}) })
.map_err(InferError::TemplateError) .map_err(InferError::TemplateError)
} }
} }
pub struct ToolGrammar {}
impl ToolGrammar {
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
}
}
/// Batching logic /// Batching logic
/// Will be launched in a background Tokio task /// Will be launched in a background Tokio task
/// ///
@ -757,6 +918,8 @@ pub enum InferError {
IncompleteGeneration, IncompleteGeneration,
#[error("Template error: {0}")] #[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error), TemplateError(#[from] minijinja::Error),
#[error("Tool error: {0}")]
ToolError(String),
} }
impl InferError { impl InferError {
@ -767,6 +930,7 @@ impl InferError {
InferError::ValidationError(_) => "validation", InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation", InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error", InferError::TemplateError(_) => "template_error",
InferError::ToolError(_) => "tool_error",
} }
} }
} }
@ -838,6 +1002,7 @@ mod tests {
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: true, add_generation_prompt: true,
..Default::default()
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -913,6 +1078,7 @@ mod tests {
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: true, add_generation_prompt: true,
..Default::default()
}; };
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@ -987,6 +1153,7 @@ mod tests {
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: true, add_generation_prompt: true,
..Default::default()
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -1045,6 +1212,7 @@ mod tests {
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: true, add_generation_prompt: true,
..Default::default()
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -1104,6 +1272,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some(""), eos_token: Some(""),
..Default::default()
}, },
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
}, },
@ -1115,6 +1284,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>", target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
}, },
@ -1126,6 +1296,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>", target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
}, },
@ -1137,6 +1308,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>", target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
}, },
@ -1148,6 +1320,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("<|endoftext|>"), eos_token: Some("<|endoftext|>"),
..Default::default()
}, },
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
}, },
@ -1159,8 +1332,9 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("<|endoftext|>"), eos_token: Some("<|endoftext|>"),
..Default::default()
}, },
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
}, },
ChatTemplateTestItem { ChatTemplateTestItem {
name: "llama", name: "llama",
@ -1171,8 +1345,9 @@ mod tests {
add_generation_prompt: true, add_generation_prompt: true,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]" target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
}, },
ChatTemplateTestItem { ChatTemplateTestItem {
name: "whisper", name: "whisper",
@ -1182,10 +1357,10 @@ mod tests {
add_generation_prompt: true, add_generation_prompt: true,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("<|endoftext|>"), eos_token: Some("<|endoftext|>"),
..Default::default()
},
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
}, },
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
}
]; ];
#[allow(unused_variables)] // name is unused #[allow(unused_variables)] // name is unused
@ -1211,7 +1386,8 @@ mod tests {
messages: example_chat_with_system.clone(), messages: example_chat_with_system.clone(),
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("</s>") eos_token: Some("</s>"),
..Default::default()
}, },
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>", target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
}, },
@ -1236,8 +1412,9 @@ mod tests {
add_generation_prompt: true, add_generation_prompt: true,
bos_token: Some(""), bos_token: Some(""),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>" target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
}, },
ChatTemplateTestItem { ChatTemplateTestItem {
name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
@ -1247,6 +1424,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<bos>"), bos_token: Some("<bos>"),
eos_token: Some("<eos>"), eos_token: Some("<eos>"),
..Default::default()
}, },
target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
}, },
@ -1258,8 +1436,9 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]" target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
}, },
ChatTemplateTestItem { ChatTemplateTestItem {
name: "mistralai/Mixtral-8x7B-Instruct-v0.1", name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
@ -1269,6 +1448,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]", target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]",
}, },
@ -1280,6 +1460,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
}, },
@ -1292,6 +1473,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>", target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>",
}, },
@ -1303,6 +1485,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>", target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
}, },
@ -1315,6 +1498,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ", target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ",
}, },
@ -1326,6 +1510,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!", target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!",
}, },
@ -1337,6 +1522,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!", target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!",
}, },
@ -1348,6 +1534,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<begin▁of▁sentence>"), bos_token: Some("<begin▁of▁sentence>"),
eos_token: Some("<end▁of▁sentence>"), eos_token: Some("<end▁of▁sentence>"),
..Default::default()
}, },
target: "<begin▁of▁sentence>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<end▁of▁sentence>User: I'd like to show off how chat templating works!\n\n", target: "<begin▁of▁sentence>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<end▁of▁sentence>User: I'd like to show off how chat templating works!\n\n",
}, },
@ -1359,8 +1546,9 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>" target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
}, },
ChatTemplateTestItem { ChatTemplateTestItem {
name: "internlm/internlm2-chat-7b", name: "internlm/internlm2-chat-7b",
@ -1370,6 +1558,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
}, },
@ -1381,6 +1570,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<begin▁of▁sentence>"), bos_token: Some("<begin▁of▁sentence>"),
eos_token: Some("<|EOT|>"), eos_token: Some("<|EOT|>"),
..Default::default()
}, },
target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n", target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n",
}, },
@ -1393,6 +1583,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<|endoftext|>"), bos_token: Some("<|endoftext|>"),
eos_token: Some("<|endoftext|>"), eos_token: Some("<|endoftext|>"),
..Default::default()
}, },
target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!", target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!",
}, },
@ -1404,6 +1595,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]", target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
}, },
@ -1415,6 +1607,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!", target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!",
}, },
@ -1426,6 +1619,7 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<begin▁of▁sentence>"), bos_token: Some("<begin▁of▁sentence>"),
eos_token: Some("</EOT>"), eos_token: Some("</EOT>"),
..Default::default()
}, },
target: "<begin▁of▁sentence>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n", target: "<begin▁of▁sentence>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
}, },
@ -1441,9 +1635,10 @@ mod tests {
add_generation_prompt: false, add_generation_prompt: false,
bos_token: Some("<s>"), bos_token: Some("<s>"),
eos_token: Some("</s>"), eos_token: Some("</s>"),
..Default::default()
}, },
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
} },
]; ];
#[allow(unused_variables)] // name is unused #[allow(unused_variables)] // name is unused


@ -1,3 +1,4 @@
pub mod config;
mod health; mod health;
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
mod infer; mod infer;
@ -48,9 +49,22 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>, pub pipeline_tag: Option<String>,
} }
#[derive(Clone, Deserialize, Default)] #[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ChatTemplate {
name: String,
template: String,
}
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ChatTemplateVersions {
Single(String),
Multiple(Vec<ChatTemplate>),
}
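The untagged enum means the `chat_template` entry in `tokenizer_config.json` can be either a single Jinja string or a list of named templates. A minimal, self-contained sketch of the expected serde behaviour (not from the diff; it only assumes `serde` and `serde_json` as dependencies):

use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
struct ChatTemplate {
    name: String,
    template: String,
}

#[derive(Debug, Deserialize, PartialEq)]
#[serde(untagged)]
enum ChatTemplateVersions {
    Single(String),
    Multiple(Vec<ChatTemplate>),
}

fn main() {
    // A bare string deserializes into the Single variant...
    let single: ChatTemplateVersions = serde_json::from_str(r#""{{ messages }}""#).unwrap();
    assert_eq!(single, ChatTemplateVersions::Single("{{ messages }}".to_string()));

    // ...while a list of {name, template} objects becomes Multiple.
    let multiple: ChatTemplateVersions = serde_json::from_str(
        r#"[{"name": "default", "template": "{{ messages }}"}]"#,
    )
    .unwrap();
    assert!(matches!(multiple, ChatTemplateVersions::Multiple(ref v) if v.len() == 1));
}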
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubTokenizerConfig { pub struct HubTokenizerConfig {
pub chat_template: Option<String>, pub chat_template: Option<ChatTemplateVersions>,
pub completion_template: Option<String>, pub completion_template: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")] #[serde(deserialize_with = "token_serde::deserialize")]
pub bos_token: Option<String>, pub bos_token: Option<String>,
@ -65,7 +79,7 @@ impl HubTokenizerConfig {
} }
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[serde(tag = "type", content = "value")] #[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType { pub(crate) enum GrammarType {
/// A string that represents a [JSON Schema](https://json-schema.org/). /// A string that represents a [JSON Schema](https://json-schema.org/).
@ -141,6 +155,8 @@ pub struct Info {
pub max_batch_size: Option<usize>, pub max_batch_size: Option<usize>,
#[schema(example = "2")] #[schema(example = "2")]
pub validation_workers: usize, pub validation_workers: usize,
#[schema(example = "32")]
pub max_client_batch_size: usize,
/// Router Info /// Router Info
#[schema(example = "0.5.0")] #[schema(example = "0.5.0")]
pub version: &'static str, pub version: &'static str,
@ -222,7 +238,7 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")] #[schema(default = "true")]
pub details: bool, pub details: bool,
#[serde(default)] #[serde(default)]
#[schema(default = "true")] #[schema(default = "false")]
pub decoder_input_details: bool, pub decoder_input_details: bool,
#[serde(default)] #[serde(default)]
#[schema( #[schema(
@ -236,6 +252,7 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>, pub top_n_tokens: Option<u32>,
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>, pub grammar: Option<GrammarType>,
} }
@ -266,6 +283,34 @@ fn default_parameters() -> GenerateParameters {
} }
} }
mod prompt_serde {
use serde::{self, Deserialize, Deserializer};
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(vec![s]),
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
"Empty array detected. Do not use an empty array for the prompt.",
)),
Value::Array(arr) => arr
.iter()
.map(|v| match v {
Value::String(s) => Ok(s.to_owned()),
_ => Err(serde::de::Error::custom("Expected a string")),
})
.collect(),
_ => Err(serde::de::Error::custom(
"Expected a string or an array of strings",
)),
}
}
}
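As a usage sketch, the `prompt` field of `CompletionRequest` now accepts either a single string or a non-empty array of strings. The deserializer is repeated below only so the example compiles on its own (assuming `serde` and `serde_json` as dependencies); the real one lives in the `prompt_serde` module above.

use serde::{Deserialize, Deserializer};
use serde_json::Value;

// Copy of the deserializer above, repeated here so the sketch is self-contained.
fn prompt<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
    match Value::deserialize(deserializer)? {
        Value::String(s) => Ok(vec![s]),
        Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
            "Empty array detected. Do not use an empty array for the prompt.",
        )),
        Value::Array(arr) => arr
            .iter()
            .map(|v| match v {
                Value::String(s) => Ok(s.to_owned()),
                _ => Err(serde::de::Error::custom("Expected a string")),
            })
            .collect(),
        _ => Err(serde::de::Error::custom(
            "Expected a string or an array of strings",
        )),
    }
}

#[derive(Deserialize)]
struct Req {
    #[serde(deserialize_with = "prompt")]
    prompt: Vec<String>,
}

fn main() {
    // A plain string becomes a one-element batch.
    let single: Req = serde_json::from_str(r#"{"prompt": "What is Deep Learning?"}"#).unwrap();
    assert_eq!(single.prompt, vec!["What is Deep Learning?"]);

    // An array of strings is taken as-is.
    let batch: Req = serde_json::from_str(r#"{"prompt": ["a", "b", "c"]}"#).unwrap();
    assert_eq!(batch.prompt.len(), 3);

    // An empty array is rejected.
    assert!(serde_json::from_str::<Req>(r#"{"prompt": []}"#).is_err());
}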
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest { pub struct CompletionRequest {
/// UNUSED /// UNUSED
@ -275,7 +320,8 @@ pub struct CompletionRequest {
/// The prompt to generate completions for. /// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
pub prompt: String, #[serde(deserialize_with = "prompt_serde::deserialize")]
pub prompt: Vec<String>,
/// The maximum number of tokens that can be generated in the chat completion. /// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)] #[serde(default)]
@ -655,7 +701,7 @@ pub(crate) struct ChatRequest {
#[serde(default = "default_tool_prompt")] #[serde(default = "default_tool_prompt")]
#[schema( #[schema(
nullable = true, nullable = true,
example = "\"Based on the conversation, please choose the most appropriate tool to use: \"" example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
)] )]
pub tool_prompt: Option<String>, pub tool_prompt: Option<String>,
@ -668,7 +714,7 @@ pub(crate) struct ChatRequest {
fn default_tool_prompt() -> Option<String> { fn default_tool_prompt() -> Option<String> {
Some( Some(
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(), "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
) )
} }
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, ToSchema, Serialize)]
@ -713,26 +759,26 @@ mod deserialize_tool_choice {
} }
} }
#[derive(Debug, Deserialize, Serialize, ToSchema)] #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools { pub struct Tools {
#[serde(flatten)] #[serde(flatten)]
functions_map: FunctionsMap, functions_map: FunctionsMap,
properties: Properties, properties: Properties,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionsMap { struct FunctionsMap {
#[serde(rename = "$functions")] #[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>, functions: std::collections::HashMap<String, serde_json::Value>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionRef { struct FunctionRef {
#[serde(rename = "$ref")] #[serde(rename = "$ref")]
ref_path: String, ref_path: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Properties { struct Properties {
#[serde(serialize_with = "serialize_function")] #[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>, function: Vec<FunctionRef>,
@ -753,7 +799,8 @@ pub(crate) struct FunctionDefinition {
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
pub name: String, pub name: String,
pub parameters: serde_json::Value, #[serde(alias = "parameters")]
pub arguments: serde_json::Value,
} }
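Because of the `parameters` alias, both the OpenAI-style `parameters` key and the newer `arguments` key land in the same field. A hedged, standalone sketch (the field set mirrors the struct above; the weather-tool payload is invented for illustration):

use serde::Deserialize;
use serde_json::{json, Value};

#[derive(Deserialize)]
struct FunctionDefinition {
    #[serde(default)]
    description: Option<String>,
    name: String,
    #[serde(alias = "parameters")]
    arguments: Value,
}

fn main() {
    // Older payloads that still send "parameters"...
    let old: FunctionDefinition = serde_json::from_value(json!({
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
    }))
    .unwrap();

    // ...and newer payloads that send "arguments" directly both deserialize.
    let new: FunctionDefinition = serde_json::from_value(json!({
        "name": "get_weather",
        "arguments": {"type": "object", "properties": {"location": {"type": "string"}}}
    }))
    .unwrap();

    assert_eq!(old.name, new.name);
    assert_eq!(old.arguments, new.arguments);
    assert!(old.description.is_none());
}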
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@ -765,12 +812,14 @@ pub(crate) struct Tool {
pub function: FunctionDefinition, pub function: FunctionDefinition,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> { pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>, messages: Vec<Message>,
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
tools: Option<&'a str>,
tools_prompt: Option<&'a str>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
@ -977,7 +1026,10 @@ mod tests {
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap(); let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
// check that we successfully parsed the tokens // check that we successfully parsed the tokens
assert_eq!(config.chat_template, Some("test".to_string())); assert_eq!(
config.chat_template,
Some(ChatTemplateVersions::Single("test".to_string()))
);
assert_eq!( assert_eq!(
config.bos_token, config.bos_token,
Some("<begin▁of▁sentence>".to_string()) Some("<begin▁of▁sentence>".to_string())
@ -1009,7 +1061,10 @@ mod tests {
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap(); let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
// check that we successfully parsed the tokens // check that we successfully parsed the tokens
assert_eq!(config.chat_template, Some("test".to_string())); assert_eq!(
config.chat_template,
Some(ChatTemplateVersions::Single("test".to_string()))
);
assert_eq!( assert_eq!(
config.bos_token, config.bos_token,
Some("<begin▁of▁sentence>".to_string()) Some("<begin▁of▁sentence>".to_string())


@ -13,6 +13,7 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::Path;
use text_generation_client::{ClientError, ShardedClient}; use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
use thiserror::Error; use thiserror::Error;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
@ -34,7 +35,7 @@ struct Args {
#[clap(default_value = "5", long, env)] #[clap(default_value = "5", long, env)]
max_top_n_tokens: u32, max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)] #[clap(default_value = "1024", long, env)]
max_input_length: usize, max_input_tokens: usize,
#[clap(default_value = "2048", long, env)] #[clap(default_value = "2048", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
@ -77,6 +78,8 @@ struct Args {
messages_api_enabled: bool, messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
} }
#[tokio::main] #[tokio::main]
@ -89,7 +92,7 @@ async fn main() -> Result<(), RouterError> {
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
@ -111,19 +114,20 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge, ngrok_edge,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
init_logging(otlp_endpoint, json_output); init_logging(otlp_endpoint, json_output);
// Validate args // Validate args
if max_input_length >= max_total_tokens { if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(), "`max_input_tokens` must be < `max_total_tokens`".to_string(),
)); ));
} }
if max_input_length as u32 > max_batch_prefill_tokens { if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
} }
if validation_workers == 0 { if validation_workers == 0 {
@ -191,15 +195,19 @@ async fn main() -> Result<(), RouterError> {
}; };
// Load tokenizer and model info // Load tokenizer and model info
let (tokenizer, model_info) = if local_model { let (tokenizer, model_info, config) = if local_model {
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok(); let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
let model_info = HubModelInfo { let model_info = HubModelInfo {
model_id: tokenizer_name.to_string(), model_id: tokenizer_name.to_string(),
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}; };
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
.ok()
.as_ref()
.and_then(|c| serde_json::from_str(c).ok());
(tokenizer, model_info) (tokenizer, model_info, config)
} else if let Some(api) = api.clone() { } else if let Some(api) = api.clone() {
let api_repo = api.repo(Repo::with_revision( let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(), tokenizer_name.to_string(),
@ -212,6 +220,19 @@ async fn main() -> Result<(), RouterError> {
Err(_) => get_base_tokenizer(&api, &api_repo).await, Err(_) => get_base_tokenizer(&api, &api_repo).await,
}; };
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| { let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub."); tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { HubModelInfo {
@ -221,7 +242,7 @@ async fn main() -> Result<(), RouterError> {
} }
}); });
(tokenizer, model_info) (tokenizer, model_info, config)
} else { } else {
// No API and no local model // No API and no local model
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
@ -229,6 +250,8 @@ async fn main() -> Result<(), RouterError> {
)); ));
}; };
tracing::info!("Using config {config:?}");
// Load tokenizer config if found locally, or check if we can get it from the API if needed // Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if let Some(path) = tokenizer_config_path { let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path"); tracing::info!("Using local tokenizer config from user specified path");
@ -291,7 +314,7 @@ async fn main() -> Result<(), RouterError> {
tracing::info!("Warming up model"); tracing::info!("Warming up model");
let max_supported_batch_total_tokens = match sharded_client let max_supported_batch_total_tokens = match sharded_client
.warmup( .warmup(
max_input_length as u32, max_input_tokens as u32,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_total_tokens as u32, max_total_tokens as u32,
max_batch_size, max_batch_size,
@ -354,7 +377,7 @@ async fn main() -> Result<(), RouterError> {
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
@ -363,6 +386,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_size, max_batch_size,
sharded_client, sharded_client,
tokenizer, tokenizer,
config,
validation_workers, validation_workers,
addr, addr,
cors_allow_origin, cors_allow_origin,
@ -372,6 +396,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_config, tokenizer_config,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size,
) )
.await?; .await?;
Ok(()) Ok(())
@ -381,12 +406,15 @@ async fn main() -> Result<(), RouterError> {
/// - otlp_endpoint is an optional URL to an Open Telemetry collector /// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) /// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) { fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new(); let mut layers = Vec::new();
// STDOUT/STDERR layer // STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
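    // i.e. ANSI colors stay enabled unless the LOG_COLORIZE env var is set to "1"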
let fmt_layer = tracing_subscriber::fmt::layer() let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true) .with_file(true)
.with_ansi(ansi)
.with_line_number(true); .with_line_number(true);
let fmt_layer = match json_output { let fmt_layer = match json_output {


@ -190,16 +190,22 @@ impl State {
token_budget: u32, token_budget: u32,
) -> Option<NextBatch> { ) -> Option<NextBatch> {
if self.entries.is_empty() { if self.entries.is_empty() {
tracing::debug!("No queue");
return None; return None;
} }
// Check if we have enough entries // Check if we have enough entries
if let Some(min_size) = min_size { if let Some(min_size) = min_size {
if self.entries.len() < min_size { if self.entries.len() < min_size {
tracing::debug!("Not enough entries");
return None; return None;
} }
} }
// Pad prefill_token_budget to be a multiple of block size
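        // (e.g. with block_size = 16, a prefill_token_budget of 1000 is rounded up to 1008)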
let prefill_token_budget =
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current()); next_batch_span.follows_from(&Span::current());
@ -218,6 +224,7 @@ impl State {
// was dropped by the client) // was dropped by the client)
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
tracing::debug!("Dropping entry");
continue; continue;
} }
@ -254,10 +261,12 @@ impl State {
{ {
// Entry is over budget // Entry is over budget
// Add it back to the front // Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry)); self.entries.push_front((id, entry));
break; break;
} }
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer"); let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships // Add relationships
@ -288,6 +297,7 @@ impl State {
// Empty batch // Empty batch
if batch_requests.is_empty() { if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None; return None;
} }


@ -1,6 +1,7 @@
use crate::config::Config;
/// HTTP Server logic /// HTTP Server logic
use crate::health::Health; use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@ -14,7 +15,8 @@ use crate::{
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
}; };
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use crate::{FunctionDefinition, ToolCall, ToolType};
use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
@ -22,20 +24,21 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{http, Json, Router}; use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream; use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient}; use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal; use tokio::signal;
use tokio::sync::oneshot;
use tokio::time::Instant; use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
@ -161,10 +164,20 @@ async fn generate(
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
generate_internal(infer, ComputeType(compute_type), Json(req), span).await
}
async fn generate_internal(
infer: Extension<Infer>,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.inputs); // Do not long ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
let compute_characters = req.inputs.chars().count(); let compute_characters = req.inputs.chars().count();
let mut add_prompt = None; let mut add_prompt = None;
@ -358,12 +371,13 @@ async fn generate_stream(
HeaderMap, HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>, Sse<impl Stream<Item = Result<Event, Infallible>>>,
) { ) {
let span = tracing::Span::current();
let on_message_callback = |stream_token: StreamResponse| { let on_message_callback = |stream_token: StreamResponse| {
let event = Event::default(); let event = Event::default();
event.json_data(stream_token).unwrap() event.json_data(stream_token).unwrap()
}; };
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await; generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse) (headers, sse)
} }
@ -373,8 +387,8 @@ async fn generate_stream_internal(
ComputeType(compute_type): ComputeType, ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event, on_message_callback: impl Fn(StreamResponse) -> Event,
span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) { ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
@ -547,7 +561,11 @@ async fn generate_stream_internal(
path = "/v1/completions", path = "/v1/completions",
request_body = CompletionRequest, request_body = CompletionRequest,
responses( responses(
(status = 200, description = "Generated Text", body = ChatCompletionChunk), (status = 200, description = "Generated Chat Completion",
content(
("application/json" = Completion),
("text/event-stream" = CompletionCompleteChunk),
)),
(status = 424, description = "Generation Error", body = ErrorResponse, (status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})), example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse, (status = 429, description = "Model is overloaded", body = ErrorResponse,
@ -576,6 +594,7 @@ async fn completions(
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(req): Json<CompletionRequest>, Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let stream = req.stream; let stream = req.stream;
@ -595,9 +614,25 @@ async fn completions(
)); ));
} }
// build the request passing some parameters if req.prompt.len() > info.max_client_batch_size {
let generate_request = GenerateRequest { metrics::increment_counter!("tgi_request_failure", "err" => "validation");
inputs: req.prompt.to_string(), return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: format!(
"Number of prompts exceeds the maximum allowed batch size of {}",
info.max_client_batch_size
),
error_type: "batch size exceeded".to_string(),
}),
));
}
let generate_requests: Vec<GenerateRequest> = req
.prompt
.iter()
.map(|prompt| GenerateRequest {
inputs: prompt.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature, temperature: req.temperature,
@ -618,9 +653,25 @@ async fn completions(
top_n_tokens: None, top_n_tokens: None,
grammar: None, grammar: None,
}, },
}; })
.collect();
let mut x_compute_type = None;
let mut x_compute_characters = 0u32;
let mut x_accel_buffering = None;
if stream { if stream {
let mut response_streams = FuturesOrdered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let model_id = info.model_id.clone();
let system_fingerprint =
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
// Create a future for each generate_stream_internal call.
let generate_future = async move {
let on_message_callback = move |stream_token: StreamResponse| { let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default(); let event = Event::default();
@ -637,50 +688,158 @@ async fn completions(
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
finish_reason: "".to_string(), finish_reason: "".to_string(),
index: 0, index: index as u32,
logprobs: None, logprobs: None,
text: stream_token.token.text, text: stream_token.token.text,
}], }],
model: info.model_id.clone(), model: model_id.clone(),
system_fingerprint: format!( system_fingerprint: system_fingerprint.clone(),
"{}-{}",
info.version,
info.docker_label.unwrap_or("native")
),
}) })
.map_or_else( .map_or_else(|_e| Event::default(), |data| data)
|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
},
|data| data,
)
}; };
let (headers, response_stream) = generate_stream_internal( let (header_tx, header_rx) = oneshot::channel();
infer, let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
compute_type,
tokio::spawn(async move {
let (header_map, sse) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request), Json(generate_request),
on_message_callback, on_message_callback,
span_clone.clone(),
) )
.await; .await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); // send and don't wait for the response
let _ = header_tx.send(header_map);
// pin and emit messages to the sse_tx
let mut sse = Box::pin(sse);
while let Some(event) = sse.next().await {
if sse_tx.send(event).is_err() {
tracing::error!("Failed to send event. Receiver dropped.");
break;
}
}
});
(header_rx, sse_rx)
};
response_streams.push_back(generate_future);
}
let mut all_rxs = vec![];
while let Some((header_rx, sse_rx)) = response_streams.next().await {
all_rxs.push(sse_rx);
// get the headers from the first response of each stream
let headers = header_rx.await.map_err(|e| {
tracing::error!("Failed to get headers: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to get headers".to_string(),
error_type: "headers".to_string(),
}),
)
})?;
if x_compute_type.is_none() {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
x_accel_buffering = headers
.get("x-accel-buffering")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
}
x_compute_characters += headers
.get("x-compute-characters")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok())
.unwrap_or(0);
}
let mut headers = HeaderMap::new();
if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into());
if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
// now sink the sse streams into a single stream and remove the ones that are done
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
loop {
let mut i = 0;
while i < all_rxs.len() {
let rx = &mut all_rxs[i];
select! {
Some(event) = rx.recv() => {
yield event;
}
else => {
all_rxs.remove(i);
continue; // skip the increment to handle the next element at the same index
}
}
i += 1; // only increment when no element was removed
}
if all_rxs.is_empty() {
break;
}
}
};
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, Json(generation)) = generate(
Extension(infer),
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
let responses = FuturesUnordered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
let response_future = async move {
let result = generate_internal(
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span_clone,
)
.await;
result.map(|(headers, generation)| (index, headers, generation))
};
responses.push(response_future);
}
let generate_responses = responses.try_collect::<Vec<_>>().await?;
let mut prompt_tokens = 0u32;
let mut completion_tokens = 0u32;
let mut total_tokens = 0u32;
let mut x_compute_time = 0u32;
let mut x_total_time = 0u32;
let mut x_validation_time = 0u32;
let mut x_queue_time = 0u32;
let mut x_inference_time = 0u32;
let mut x_time_per_token = 0u32;
let mut x_prompt_tokens = 0u32;
let mut x_generated_tokens = 0u32;
let choices = generate_responses
.into_iter()
.map(|(index, headers, Json(generation))| {
let details = generation.details.ok_or(( let details = generation.details.ok_or((
// this should never happen but handle if details are missing unexpectedly // this should never happen but handle if details are missing unexpectedly
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
@ -690,6 +849,65 @@ async fn completions(
}), }),
))?; ))?;
if x_compute_type.is_none() {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
}
// accumulate headers and usage from each response
x_compute_time += headers
.get("x-compute-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_compute_characters += headers
.get("x-compute-characters")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_total_time += headers
.get("x-total-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_validation_time += headers
.get("x-validation-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_queue_time += headers
.get("x-queue-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_inference_time += headers
.get("x-inference-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_time_per_token += headers
.get("x-time-per-token")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_prompt_tokens += headers
.get("x-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_generated_tokens += headers
.get("x-generated-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
prompt_tokens += details.prefill.len() as u32;
completion_tokens += details.generated_tokens;
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
Ok(CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: generation.generated_text,
})
})
.collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?;
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(), object: "text_completion".to_string(),
@ -700,19 +918,30 @@ async fn completions(
info.version, info.version,
info.docker_label.unwrap_or("native") info.docker_label.unwrap_or("native")
), ),
choices: vec![CompletionComplete { choices,
finish_reason: details.finish_reason.to_string(),
index: 0,
logprobs: None,
text: generation.generated_text,
}],
usage: Usage { usage: Usage {
prompt_tokens: details.prefill.len() as u32, prompt_tokens,
completion_tokens: details.generated_tokens, completion_tokens,
total_tokens: details.prefill.len() as u32 + details.generated_tokens, total_tokens,
}, },
}; };
// headers similar to `generate` but aggregated
let mut headers = HeaderMap::new();
if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-total-time", x_total_time.into());
headers.insert("x-validation-time", x_validation_time.into());
headers.insert("x-queue-time", x_queue_time.into());
headers.insert("x-inference-time", x_inference_time.into());
headers.insert("x-time-per-token", x_time_per_token.into());
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
headers.insert("x-generated-tokens", x_generated_tokens.into());
if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
Ok((headers, Json(response)).into_response()) Ok((headers, Json(response)).into_response())
} }
} }
@ -724,7 +953,11 @@ async fn completions(
path = "/v1/chat/completions", path = "/v1/chat/completions",
request_body = ChatRequest, request_body = ChatRequest,
responses( responses(
(status = 200, description = "Generated Text", body = ChatCompletionChunk), (status = 200, description = "Generated Chat Completion",
content(
("application/json" = ChatCompletion),
("text/event-stream" = ChatCompletionChunk),
)),
(status = 424, description = "Generation Error", body = ErrorResponse, (status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})), example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse, (status = 429, description = "Model is overloaded", body = ErrorResponse,
@ -753,21 +986,32 @@ async fn chat_completions(
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>, Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let stream = req.stream; let ChatRequest {
let max_new_tokens = req.max_tokens.or(Some(100)); logprobs,
let repetition_penalty = req max_tokens,
.presence_penalty messages,
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0) presence_penalty,
.map(|x| x + 2.0); seed,
let logprobs = req.logprobs.unwrap_or(false); stop,
let seed = req.seed; stream,
let stop = req.stop.unwrap_or_default(); tools,
tool_choice,
tool_prompt,
..
} = req;
// apply chat template to flatten the request into a single input let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let mut inputs = match infer.apply_chat_template(req.messages) { let max_new_tokens = max_tokens.or(Some(100));
Ok(inputs) => inputs, let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default();
// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
@ -781,60 +1025,28 @@ async fn chat_completions(
} }
}; };
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) { let grammar_with_prompt = tool_grammar
let tool_prompt = req.tool_prompt.unwrap_or_default(); .as_ref()
let tools_to_use = match tool_choice { .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
ToolType::FunctionName(name) => {
vec![req_tools let typed_grammar = grammar_with_prompt
.iter() .as_ref()
.find(|tool| tool.function.name == *name) .map(|(grammar, _)| grammar.clone());
.ok_or_else(|| {
( // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse { Json(ErrorResponse {
error: "Tool choice not found in tool names".to_string(), error: err.to_string(),
error_type: "Tool not found".to_string(), error_type: err.error_type().to_string(),
}), }),
) ));
})?
.clone()]
} }
ToolType::OneOf => req_tools.to_owned(),
};
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
(func.name, func.parameters)
})
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools)))
} else {
None
}; };
// build the request passing some parameters // build the request passing some parameters
@ -858,7 +1070,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar: tool_grammar.clone(), grammar: typed_grammar,
}, },
}; };
@ -912,17 +1124,14 @@ async fn chat_completions(
compute_type, compute_type,
Json(generate_request), Json(generate_request),
on_message_callback, on_message_callback,
span,
) )
.await; .await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, Json(generation)) = generate( let (headers, Json(generation)) =
Extension(infer), generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -941,27 +1150,28 @@ async fn chat_completions(
}), }),
) )
})?; })?;
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
id: 0, id: 0,
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionDefinition {
description: None, description: None,
name: "tools".to_string(), name: gen_text_value
parameters: gen_text_value.get("function").map_or_else( .get("function")
|| { .and_then(|f| f.get("_name"))
serde_json::from_str(&generation.generated_text).map_err(|e| { .and_then(|name| name.as_str())
( .unwrap_or("default_function_name")
StatusCode::UNPROCESSABLE_ENTITY, .to_string(),
Json(ErrorResponse { // Serialize the JSON object obtained from "function" to an escaped JSON string
error: e.to_string(), arguments: gen_text_value
error_type: "Input validation error".to_string(), .get("function")
}), .map(|f| {
) let mut f_cloned = f.clone();
if let Value::Object(ref mut props) = f_cloned {
props.remove("_name");
}
f_cloned
}) })
}, .unwrap_or_default(),
|f| Ok(f.clone()),
)?,
}, },
}]; }];
(Some(tool_calls), None) (Some(tool_calls), None)
@ -1018,6 +1228,7 @@ async fn vertex_compatibility(
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>, Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
// check that theres at least one instance // check that theres at least one instance
@ -1049,10 +1260,11 @@ async fn vertex_compatibility(
}; };
async { async {
generate( generate_internal(
Extension(infer.clone()), Extension(infer.clone()),
Extension(compute_type.clone()), compute_type.clone(),
Json(generate_request), Json(generate_request),
span.clone(),
) )
.await .await
.map(|(_, Json(generation))| generation.generated_text) .map(|(_, Json(generation))| generation.generated_text)
@ -1154,6 +1366,7 @@ pub async fn run(
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
client: ShardedClient, client: ShardedClient,
tokenizer: Option<Tokenizer>, tokenizer: Option<Tokenizer>,
config: Option<Config>,
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>, allow_origin: Option<AllowOrigin>,
@ -1163,6 +1376,7 @@ pub async fn run(
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -1236,6 +1450,7 @@ pub async fn run(
let validation = Validation::new( let validation = Validation::new(
validation_workers, validation_workers,
tokenizer, tokenizer,
config,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
@ -1336,6 +1551,7 @@ pub async fn run(
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
validation_workers, validation_workers,
max_client_batch_size,
version: env!("CARGO_PKG_VERSION"), version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"), sha: option_env!("VERGEN_GIT_SHA"),
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),
@ -1535,6 +1751,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
}; };
( (

View File

@ -1,15 +1,19 @@
use crate::config::Config;
/// Payload validation logic /// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType}; use crate::{GenerateParameters, GenerateRequest, GrammarType};
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{ use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
}; };
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection; // use tokenizers::TruncationDirection;
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tracing::{instrument, Span}; use tracing::{instrument, Span};
@ -34,6 +38,7 @@ impl Validation {
pub(crate) fn new( pub(crate) fn new(
workers: usize, workers: usize,
tokenizer: Option<Tokenizer>, tokenizer: Option<Tokenizer>,
config: Option<Config>,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32, max_top_n_tokens: u32,
@ -50,12 +55,13 @@ impl Validation {
// Create workers // Create workers
for _ in 0..workers { for _ in 0..workers {
let tokenizer_clone = tokenizer.clone(); let tokenizer_clone = tokenizer.clone();
let config_clone = config.clone();
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel(); let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
senders.push(tokenizer_sender); senders.push(tokenizer_sender);
// Spawn worker // Spawn worker
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
tokenizer_worker(tokenizer_clone, tokenizer_receiver) tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
}); });
} }
@ -155,14 +161,17 @@ impl Validation {
} else { } else {
return Err(ValidationError::UnsetMaxNewTokens); return Err(ValidationError::UnsetMaxNewTokens);
}; };
let input_length = truncate.unwrap_or(self.max_input_length); let mut input_length = truncate.unwrap_or(self.max_input_length);
// We don't have a tokenizer, therefore we have no idea how long is the query, let
// them through and hope for the best.
// Validate MaxNewTokens // Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
    return Err(ValidationError::MaxNewTokens(
        self.max_total_tokens - self.max_input_length,
        max_new_tokens,
    ));
    input_length = input_length.saturating_sub(max_new_tokens as usize);
    // return Err(ValidationError::MaxNewTokens(
    //     self.max_total_tokens - self.max_input_length,
    //     max_new_tokens,
    // ));
} }
Ok((inputs, input_length, max_new_tokens)) Ok((inputs, input_length, max_new_tokens))
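With no tokenizer available the router can only guess the prompt length, so the check above now clamps the assumed input length instead of rejecting the request. A small sketch of the arithmetic (helper name hypothetical):

```python
from typing import Optional

def assumed_input_length(max_input_length: int, max_total_tokens: int,
                         max_new_tokens: int, truncate: Optional[int] = None) -> int:
    """Sketch of the relaxed check above (the no-tokenizer path)."""
    input_length = truncate if truncate is not None else max_input_length
    if input_length + max_new_tokens > max_total_tokens:
        # Assume a shorter prompt instead of returning MaxNewTokens.
        input_length = max(input_length - max_new_tokens, 0)
    return input_length

# With max_input_length=5, max_total_tokens=6 and max_new_tokens=10, the old code
# returned ValidationError::MaxNewTokens(1, 10); the new code reports length 0,
# which is what the updated unit test later in this diff expects.
```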
@ -408,48 +417,137 @@ async fn round_robin_task(
} }
/// Start tokenization workers /// Start tokenization workers
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
    // Loop over requests
    let is_multimodal = {
        let vocab = tokenizer.get_vocab(true);
        vocab.contains_key("<image>")
    };
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
                .unwrap_or(())
        })
    }
}
fn tokenizer_worker(
    tokenizer: Tokenizer,
    config: Option<Config>,
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
    // Loop over requests
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(prepare_input(inputs, truncate, &tokenizer, &config))
                .unwrap_or(())
        })
    }
}
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
match mimetype {
"image/png" => Some(ImageFormat::Png),
"image/jpeg" => Some(ImageFormat::Jpeg),
"image/jpg" => Some(ImageFormat::Jpeg),
"image/gif" => Some(ImageFormat::Gif),
"image/webp" => Some(ImageFormat::WebP),
"image/tiff" => Some(ImageFormat::Tiff),
// "image/pnm"=>Some(ImageFormat::Pnm),
// "image/tga"=>Some(ImageFormat::Tga),
// "image/dds"=>Some(ImageFormat::Dds),
// "image/bmp"=>Some(ImageFormat::Bmp),
// "image/ico"=>Some(ImageFormat::Ico),
// "image/x-exr"=>Some(ImageFormat::OpenExr),
_ => None,
}
}
fn format_to_mimetype(format: ImageFormat) -> String {
match format {
ImageFormat::Png => "image/png",
ImageFormat::Jpeg => "image/jpeg",
ImageFormat::Gif => "image/gif",
ImageFormat::WebP => "image/webp",
ImageFormat::Tiff => "image/tiff",
_ => "application/octet-stream",
}
.to_string()
}
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?;
let format = image::guess_format(&data)?;
// TODO Remove this clone
let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?;
let mimetype = format_to_mimetype(format);
let encoded = STANDARD.encode(data);
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
Ok((data_uri, height, width))
} else if input.starts_with("![](data:") {
// Remove ![](....)
let content = &input["![](data:".len()..input.len() - 1];
let tokens: Vec<_> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidImageContent(content.to_string()));
}
let mimetype = tokens[0];
let content = tokens[1];
if !content.starts_with("base64,") {
return Err(ValidationError::InvalidImageContent(content.to_string()));
}
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
let img = if let Some(format) = format_from_mimetype(mimetype) {
ImageReader::with_format(Cursor::new(data), format).decode()?
} else {
ImageReader::new(Cursor::new(data))
.with_guessed_format()
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
.decode()?
};
let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?;
Ok((input.to_string(), height, width))
} else {
Err(ValidationError::InvalidImageContent(input.to_string()))
}
}
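`fetch_image` accepts two markdown image forms: a remote `![](http…)` URL, which is downloaded and rewritten as a base64 data URI, and an inline `![](data:…)` URI, which is only decoded to validate it and measure the image. A hedged Python sketch of the same flow (function and constant names are mine, not the router's; Pillow is assumed to be available, as it already is on the Python server side):

```python
import base64
import io
from urllib.request import urlopen

from PIL import Image

MIME_BY_FORMAT = {"PNG": "image/png", "JPEG": "image/jpeg", "GIF": "image/gif",
                  "WEBP": "image/webp", "TIFF": "image/tiff"}

def fetch_image_sketch(markdown: str):
    """Rough equivalent of the Rust `fetch_image` above: returns (data_uri_or_input, height, width)."""
    if markdown.startswith("![](http://") or markdown.startswith("![](https://"):
        url = markdown[len("![]("):-1]
        data = urlopen(url).read()
        img = Image.open(io.BytesIO(data))
        mimetype = MIME_BY_FORMAT.get(img.format, "application/octet-stream")
        encoded = base64.b64encode(data).decode()
        return f"![](data:{mimetype};base64,{encoded})", img.height, img.width
    if markdown.startswith("![](data:"):
        mimetype, content = markdown[len("![](data:"):-1].split(";", 1)
        if not content.startswith("base64,"):
            raise ValueError(f"invalid image content: {content}")
        img = Image.open(io.BytesIO(base64.b64decode(content[len("base64,"):])))
        return markdown, img.height, img.width
    raise ValueError(f"invalid image content: {markdown}")
```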
/// Get input length and optionally truncate it /// Get input length and optionally truncate it
fn prepare_input( fn prepare_input(
mut inputs: String, mut inputs: String,
truncate: Option<usize>, _truncate: Option<usize>,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
is_multimodal: bool, config: &Option<Config>,
) -> Result<(tokenizers::Encoding, String), ValidationError> { ) -> Result<(tokenizers::Encoding, String), ValidationError> {
let simplified_query = if is_multimodal {
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    RE.replace_all(&inputs, "<image>").into()
} else {
    inputs.clone()
};
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let tokenizer_query = match config {
    Some(Config::LlavaNext(config)) => {
        let mut modified_inputs = String::with_capacity(inputs.len());
        let mut tokenizer_query = String::with_capacity(inputs.len());
        let mut start = 0;
        for chunk in RE.find_iter(&inputs) {
            let chunk_start = chunk.start();
            let chunk_end = chunk.end();
            if chunk_start != start {
                modified_inputs.push_str(&inputs[start..chunk_start]);
                tokenizer_query.push_str(&inputs[start..chunk_start]);
            }
            let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
            let slots = config.get_number_of_features(height, width);
            tokenizer_query.push_str(&"<image>".repeat(slots));
            modified_inputs.push_str(&image_uri);
            start = chunk_end;
        }
        if start != inputs.len() - 1 {
            modified_inputs.push_str(&inputs[start..]);
            tokenizer_query.push_str(&inputs[start..]);
        }
        inputs = modified_inputs;
        tokenizer_query
    }
    Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
    _ => inputs.clone(),
};
// Get the number of tokens in the input
let mut encoding = tokenizer
    .encode(simplified_query, true)
    .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
// Optionally truncate
if let Some(truncate) = truncate {
    if truncate < encoding.len() && !is_multimodal {
        encoding.truncate(truncate, 0, TruncationDirection::Left);
        inputs = tokenizer
            .decode(encoding.get_ids(), false)
            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
    }
}
// Get the number of tokens in the input
let encoding = tokenizer
    .encode(tokenizer_query, true)
    .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, inputs)) Ok((encoding, inputs))
} }
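For LlavaNext, every markdown image in the prompt is expanded in the tokenizer query to `<image>` repeated once per feature slot, so the measured token count matches what the model will actually see, while the prompt itself keeps the full data URI; Idefics keeps a single `<image>` placeholder. A Python sketch of the LlavaNext expansion, taking `fetch_image` and the model-specific `get_number_of_features` as given callables (trailing text is appended unconditionally here):

```python
import re

IMAGE_RE = re.compile(r"!\[\]\([^\)]*\)")

def expand_images(inputs: str, fetch_image, get_number_of_features):
    """Sketch of the LlavaNext branch above: returns (modified_inputs, tokenizer_query)."""
    modified, query = [], []
    start = 0
    for match in IMAGE_RE.finditer(inputs):
        modified.append(inputs[start:match.start()])
        query.append(inputs[start:match.start()])
        image_uri, height, width = fetch_image(match.group(0))
        slots = get_number_of_features(height, width)  # model/config specific, not shown in this diff
        query.append("<image>" * slots)
        modified.append(image_uri)
        start = match.end()
    modified.append(inputs[start:])
    query.append(inputs[start:])
    return "".join(modified), "".join(query)
```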
@ -523,6 +621,16 @@ pub enum ValidationError {
Grammar, Grammar,
#[error("grammar is not valid: {0}")] #[error("grammar is not valid: {0}")]
InvalidGrammar(String), InvalidGrammar(String),
#[error("base64 encoding is invalid: {0}")]
InvalidBase64(#[from] base64::DecodeError),
#[error("invalid image: {0}")]
InvalidImage(#[from] image::ImageError),
#[error("invalid integer: {0}")]
InvalidInt(#[from] core::num::TryFromIntError),
#[error("invalid image content: {0}")]
InvalidImageContent(String),
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
} }
#[cfg(test)] #[cfg(test)]
@ -541,9 +649,11 @@ mod tests {
let max_total_tokens = 6; let max_total_tokens = 6;
let workers = 1; let workers = 1;
let disable_grammar_support = true; let disable_grammar_support = true;
let config = None;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
config,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens, max_top_n_tokens,
@ -557,8 +667,9 @@ mod tests {
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxNewTokens(1, 10)) => (),
_ => panic!("Unexpected not max new tokens"),
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"),
} }
} }
@ -572,9 +683,11 @@ mod tests {
let max_total_tokens = 6; let max_total_tokens = 6;
let disable_grammar_support = true; let disable_grammar_support = true;
let workers = 1; let workers = 1;
let config = None;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
config,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens, max_top_n_tokens,
@ -603,9 +716,11 @@ mod tests {
let max_total_tokens = 6; let max_total_tokens = 6;
let workers = 1; let workers = 1;
let disable_grammar_support = true; let disable_grammar_support = true;
let config = None;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
config,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens, max_top_n_tokens,
@ -639,9 +754,11 @@ mod tests {
let max_total_tokens = 106; let max_total_tokens = 106;
let workers = 1; let workers = 1;
let disable_grammar_support = true; let disable_grammar_support = true;
let config = None;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
config,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens, max_top_n_tokens,
@ -704,9 +821,11 @@ mod tests {
let max_total_tokens = 106; let max_total_tokens = 106;
let workers = 1; let workers = 1;
let disable_grammar_support = true; let disable_grammar_support = true;
let config = None;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
config,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,

View File

@ -17,9 +17,6 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py touch text_generation_server/pb/__init__.py
install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
install: gen-server install: gen-server
pip install pip --upgrade pip install pip --upgrade
pip install -r requirements_cuda.txt pip install -r requirements_cuda.txt

View File

@ -1,4 +1,4 @@
eetq_commit := 71adb5e191bb8290069a580abff0355d7b2dd5c9 eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
eetq: eetq:
# Clone eetq # Clone eetq

View File

@ -1,4 +1,4 @@
flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

View File

@ -1,10 +1,10 @@
vllm-cuda: vllm-cuda:
# Clone vllm # Clone vllm
pip install -U ninja packaging --no-cache-dir pip install -U ninja packaging --no-cache-dir
git clone https://github.com/vllm-project/vllm.git vllm git clone https://github.com/OlivierDehaene/vllm.git vllm
build-vllm-cuda: vllm-cuda build-vllm-cuda: vllm-cuda
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
cd vllm && python setup.py build cd vllm && python setup.py build
install-vllm-cuda: build-vllm-cuda install-vllm-cuda: build-vllm-cuda

1358
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-server" name = "text-generation-server"
version = "1.4.5" version = "2.0.1"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1" grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0" grpc-interceptor = "^0.15.0"
typer = "^0.6.1" typer = "^0.6.1"
accelerate = { version = "^0.28.0", optional = true } accelerate = { version = "^0.29.1", optional = true }
bitsandbytes = { version = "^0.43.0", optional = true } bitsandbytes = { version = "^0.43.0", optional = true }
safetensors = "^0.4" safetensors = "^0.4"
loguru = "^0.6.0" loguru = "^0.6.0"
@ -24,13 +24,13 @@ opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0" opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2" hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97" sentencepiece = "^0.1.97"
tokenizers = "^0.15.0" tokenizers = "^0.19.1"
huggingface-hub = "^0.19.3" huggingface-hub = "^0.19.3"
transformers = "^4.38" transformers = "^4.40"
einops = "^0.6.1" einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true } texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true } datasets = { version = "^2.14.0", optional = true }
peft = { version = "^0.9", optional = true } peft = { version = "^0.10", optional = true }
torch = { version = "^2.1.1", optional = true } torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1" scipy = "^1.11.1"
pillow = "^10.0.0" pillow = "^10.0.0"

View File

@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -27,20 +27,20 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13" transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -27,20 +27,20 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13" transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -19,6 +19,7 @@ class Quantization(str, Enum):
gptq = "gptq" gptq = "gptq"
awq = "awq" awq = "awq"
eetq = "eetq" eetq = "eetq"
fp8 = "fp8"
class Dtype(str, Enum): class Dtype(str, Enum):

View File

@ -23,6 +23,10 @@ class ExceptionInterceptor(AsyncServerInterceptor):
method_name = method_name.split("/")[-1] method_name = method_name.split("/")[-1]
logger.exception(f"Method {method_name} encountered an error.") logger.exception(f"Method {method_name} encountered an error.")
# Runtime Error cannot be recovered from
if isinstance(err, RuntimeError):
exit(1)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -67,6 +67,7 @@ try:
FlashSantacoderSharded, FlashSantacoderSharded,
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_phi import FlashPhi
@ -144,7 +145,7 @@ def get_model(
if speculate is not None: if speculate is not None:
if speculate > speculate_medusa: if speculate > speculate_medusa:
raise RuntimeError( raise RuntimeError(
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match" f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
) )
else: else:
set_speculate(speculate) set_speculate(speculate)
@ -186,6 +187,14 @@ def get_model(
raise RuntimeError( raise RuntimeError(
f"Could not determine model type for {model_id} revision {revision}" f"Could not determine model type for {model_id} revision {revision}"
) )
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq"}:
logger.info(f"Auto selecting quantization method {method}")
quantize = method
else:
logger.info(f"Unknown quantization method {method}")
if model_type == "ssm": if model_type == "ssm":
return Mamba( return Mamba(
@ -571,6 +580,19 @@ def get_model(
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next":
if FLASH_ATTENTION:
return LlavaNext(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
if sharded: if sharded:
raise NotImplementedError("sharded is not supported for AutoModel") raise NotImplementedError("sharded is not supported for AutoModel")
if quantize == "gptq": if quantize == "gptq":

View File

@ -43,7 +43,7 @@ class CacheManager:
] ]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange( self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32 0, num_blocks * self.block_size, dtype=torch.int64
).view(num_blocks, self.block_size) ).view(num_blocks, self.block_size)
def allocate( def allocate(
@ -55,9 +55,10 @@ class CacheManager:
): ):
# Get free blocks indices by finding values in mask that are not set to 0 # Get free blocks indices by finding values in mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero() free_block_indices = self.free_block_mask.nonzero()
assert (
    len(free_block_indices) >= blocks
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
if blocks > len(free_block_indices):
    raise RuntimeError(
        f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
    )
# Slice by the number of required blocks # Slice by the number of required blocks
block_indices = free_block_indices[:blocks] block_indices = free_block_indices[:blocks]
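Running out of KV-cache blocks is now reported as a `RuntimeError` rather than an `AssertionError`; note that the interceptor change earlier in this diff exits the shard on `RuntimeError`. A self-contained toy version of the bookkeeping (not the real `CacheManager` API):

```python
import torch

block_size, num_blocks = 2, 4
free_block_mask = torch.ones(num_blocks, dtype=torch.int32)
slots = torch.arange(0, num_blocks * block_size, dtype=torch.int64).view(num_blocks, block_size)

def allocate(blocks: int) -> torch.Tensor:
    free_block_indices = free_block_mask.nonzero()
    if blocks > len(free_block_indices):
        raise RuntimeError(
            f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
        )
    block_indices = free_block_indices[:blocks]
    free_block_mask[block_indices] = 0
    return slots[block_indices.flatten()]

print(allocate(2))  # tensor([[0, 1], [2, 3]])
allocate(3)         # RuntimeError: Out of available cache blocks: asked 3, only 2 free blocks
```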

View File

@ -0,0 +1,827 @@
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from text_generation_server.utils.layers import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, prefix, config: CLIPVisionConfig, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
# TODO Should we TP this ?
self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
persistent=False,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype)
) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
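The embedding layer turns the image into a token sequence: a strided Conv2d cuts it into patches, the patches are flattened into tokens, a class embedding is prepended and learned position embeddings are added. A shape sketch with hypothetical CLIP-B/32-like sizes (the real sizes come from `CLIPVisionConfig`):

```python
import torch
import torch.nn as nn

batch, channels, image_size, patch_size, embed_dim = 2, 3, 224, 32, 768
patch_embedding = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)

pixel_values = torch.randn(batch, channels, image_size, image_size)
patch_embeds = patch_embedding(pixel_values)             # (2, 768, 7, 7)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)   # (2, 49, 768)
class_embeds = torch.zeros(batch, 1, embed_dim)          # stands in for the learned class embedding
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
assert embeddings.shape == (batch, (image_size // patch_size) ** 2 + 1, embed_dim)  # 49 patches + 1
```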
class CLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = (
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=True,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_size)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
# get query proj
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
]
* 3,
dim=2,
)
query_states = query_states * self.scale
key_states = self._shape(key_states, -1, bsz)
value_states = self._shape(value_states, -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_size)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ causal_attention_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
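Unlike the upstream transformers CLIP attention, q/k/v here come out of one fused tensor-parallel column linear and are split along the last dimension; each shard only sees `num_heads / world_size` heads. A quick shape check of the split and the bmm that follows (sizes are illustrative):

```python
import torch

bsz, tgt_len, num_heads, head_size = 2, 16, 4, 64
embed_dim = num_heads * head_size

qkv = torch.randn(bsz, tgt_len, 3 * embed_dim)          # fused projection output (per shard)
query, key, value = qkv.split([embed_dim] * 3, dim=2)

def _shape(t):  # (bsz, seq, embed_dim) -> (bsz * num_heads, seq, head_size)
    return t.view(bsz, -1, num_heads, head_size).transpose(1, 2).reshape(bsz * num_heads, -1, head_size)

attn_weights = torch.bmm(_shape(query * head_size**-0.5), _shape(key).transpose(1, 2))
assert attn_weights.shape == (bsz * num_heads, tgt_len, tgt_len)
attn_output = torch.bmm(attn_weights.softmax(dim=-1), _shape(value))
assert attn_output.shape == (bsz * num_heads, tgt_len, head_size)
```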
class CLIPMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class CLIPEncoderLayer(nn.Module):
def __init__(self, prefix, config: CLIPConfig, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
)
self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CLIPPreTrainedModel(nn.Module):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = CLIPConfig
base_model_prefix = "clip"
supports_gradient_checkpointing = True
CLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
"""
CLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
"""
CLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
"""
class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def __init__(self, prefix, config: CLIPConfig, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
CLIPEncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
)
return hidden_states
class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
):
r"""
Returns:
"""
if input_ids is None:
raise ValueError("You have to specify input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _create_4d_causal_attention_mask(
input_shape, hidden_states.dtype, device=hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(
attention_mask, hidden_states.dtype
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
dim=-1
),
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
(
input_ids.to(dtype=torch.int, device=last_hidden_state.device)
== self.eos_token_id
)
.int()
.argmax(dim=-1),
]
return last_hidden_state
class CLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import AutoTokenizer, CLIPTextModel
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
class CLIPVisionTransformer(nn.Module):
def __init__(self, prefix, config: CLIPVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.pre_layrnorm = nn.LayerNorm.load(
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
)
self.encoder = CLIPEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
# self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
):
r"""
Returns:
"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
)
last_hidden_state = encoder_outputs
# pooled_output = last_hidden_state[:, 0, :]
# pooled_output = self.post_layernorm(pooled_output)
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
# pooler_output=pooled_output,
# hidden_states=encoder_outputs,
)
class CLIPVisionModel(CLIPPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
self.vision_model = CLIPVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
):
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, CLIPVisionModel
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return self.vision_model(
pixel_values=pixel_values,
)
class CLIPModel(nn.Module):
def __init__(self, prefix, config: CLIPConfig, weights):
super().__init__()
text_config = config.text_config
vision_config = config.vision_config
self.projection_dim = config.projection_dim
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = CLIPTextTransformer(text_config)
self.vision_model = CLIPVisionTransformer(vision_config)
self.visual_projection = nn.Linear(
self.vision_embed_dim, self.projection_dim, bias=False
)
self.text_projection = nn.Linear(
self.text_embed_dim, self.projection_dim, bias=False
)
self.logit_scale = nn.Parameter(
torch.tensor(self.config.logit_scale_init_value)
)
# Initialize weights and apply final processing
self.post_init()
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""
Returns:
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
applying the projection layer to the pooled output of [`CLIPTextModel`].
Examples:
```python
>>> from transformers import AutoTokenizer, CLIPModel
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> text_features = model.get_text_features(**inputs)
```"""
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
pooled_output = text_outputs[1]
text_features = self.text_projection(pooled_output)
return text_features
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
applying the projection layer to the pooled output of [`CLIPVisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, CLIPModel
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(**inputs)
```"""
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
return image_features
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, CLIPModel
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
... )
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```"""
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()
return logits_per_image, logits_per_text

View File

@@ -23,10 +23,10 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -34,66 +34,107 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+    FastLayerNorm,
 )
+
+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+else:
+    dropout_layer_norm = None
-class CohereConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=256000,
-        hidden_size=8192,
-        intermediate_size=22528,
-        num_hidden_layers=40,
-        num_attention_heads=64,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=8192,
-        initializer_range=0.02,
-        layer_norm_eps=1e-5,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=5,
-        eos_token_id=255001,
-        pretraining_tp=1,
-        tie_word_embeddings=True,
-        rope_theta=10000.0,
-        attention_bias=False,
-        attention_dropout=0.0,
-        logit_scale=1.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.logit_scale = logit_scale
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
+class CohereRotary(PositionRotaryEmbedding):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        # Such controlflows may add some overhead.
+        if IS_CUDA_SYSTEM:
+            import rotary_emb
+
+            q1 = query[..., ::2]
+            q2 = query[..., 1::2]
+
+            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+            k1 = key[..., ::2]
+            k2 = key[..., 1::2]
+
+            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        elif IS_ROCM_SYSTEM:
+            from vllm import pos_encoding_ops
+
+            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
+            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
+
+            head_size = query.shape[-1]
+
+            # Inplace operation, updating query and key.
+            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
+        else:
+            raise ValueError(
+                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
+            )
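Both kernel paths above apply the same interleaved rotation in place; only the launcher differs. For reference, a plain PyTorch version of that rotation is sketched below. It is an out-of-place illustration written for this document, not code from the repository, and it assumes `cos`/`sin` of shape `(num_tokens, head_size // 2)` to match the even/odd pairing.

```python
import torch

def rotate_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, num_heads, head_size); pairs are (x[..., 2i], x[..., 2i+1])
    x1, x2 = x[..., ::2], x[..., 1::2]
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over the head dimension
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```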
class CohereLayerNorm(nn.Module):
def __init__(self, prefix, weights, eps):
super().__init__()
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
self.weight = nn.Parameter(weight)
# Fake weights
self.ones = weight.new_ones(weight.shape[1])
self.eps = eps
def forward(self, hidden_states):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
hidden_states = hidden_states.reshape(
-1, self.weight.shape[0], self.weight.shape[1]
)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
hidden_states_minus_mean = hidden_states - mean
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
hidden_states = self.weight.to(torch.float32) * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states.to(input_dtype)
(
hidden_states,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.ones,
None,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
# Required to apply one weight matrix per head
hidden_states = hidden_states.view(
-1, self.weight.shape[0], self.weight.shape[1]
)
hidden_states = self.weight * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states
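CohereLayerNorm keeps one weight vector per attention head and no bias: the slow branch normalizes by hand in float32, while the fast branch runs the fused `dropout_layer_norm` kernel with a dummy all-ones weight and applies the per-head weights afterwards. A rough equivalent using `F.layer_norm`, written here only as an illustration of what both branches compute:

```python
import torch
import torch.nn.functional as F

def per_head_layer_norm_reference(hidden_states, weight, eps):
    # hidden_states: (num_tokens * num_heads, head_size), weight: (num_heads, head_size)
    x = hidden_states.view(-1, weight.shape[0], weight.shape[1]).to(torch.float32)
    x = F.layer_norm(x, (weight.shape[1],), weight=None, bias=None, eps=eps)
    x = weight.to(torch.float32) * x
    return x.view(-1, weight.shape[1]).to(hidden_states.dtype)
```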
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -154,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
-        self.rotary_emb = PositionRotaryEmbedding.static(
+        self.rotary_emb = CohereRotary.static(
             config=config,
             dim=self.head_size,
             base=config.rope_theta,
@@ -175,6 +216,22 @@ class FlashCohereAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
+        self.use_qk_norm = config.use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = CohereLayerNorm(
+                prefix=f"{prefix}.q_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+            self.k_norm = CohereLayerNorm(
+                prefix=f"{prefix}.k_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+        else:
+            self.q_norm = None
+            self.k_norm = None
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
@@ -199,21 +256,28 @@ class FlashCohereAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        query, kv = qkv.split(
+        query, key, value = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=1,
         )
+
+        if self.use_qk_norm:
+            query = query.reshape(-1, self.head_size)
+            key = key.reshape(-1, self.head_size)
+            query = self.q_norm(query.contiguous())
+            key = self.k_norm(key.contiguous())
+
         query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key = key.view(-1, self.num_key_value_heads, self.head_size)
+        value = value.view(-1, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, key, cos, sin)

-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

         # output tensor
         attn_output = torch.empty_like(query)
@@ -223,8 +287,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             flash_attn.attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                key,
+                value,
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
@@ -298,7 +362,7 @@ class FlashCohereLayer(nn.Module):
         )
         self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
@@ -362,7 +426,7 @@ class FlashCohereModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = FastLayerNorm.load_no_bias(
             prefix="model.norm", weights=weights, eps=config.layer_norm_eps
         )

View File

@ -16,14 +16,13 @@
import torch import torch
import torch.distributed import torch.distributed
import numpy as np
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from loguru import logger from loguru import logger
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear, FastLinear,
@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
) )
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
HAS_MEGABLOCKS = True
try:
import stk
import megablocks.ops as ops
except ImportError:
logger.warning("Dbrx: megablocks is not installed")
HAS_MEGABLOCKS = False
class DbrxAttentionConfig(PretrainedConfig): class DbrxAttentionConfig(PretrainedConfig):
def __init__( def __init__(
@ -531,18 +522,6 @@ def round_up(x: torch.Tensor, value: int):
class BlockSparseMoE(nn.Module): class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accomodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""
def __init__(self, prefix, config: DbrxConfig, weights): def __init__(self, prefix, config: DbrxConfig, weights):
super().__init__() super().__init__()
self.moe_normalize_expert_weights = ( self.moe_normalize_expert_weights = (
@ -572,241 +551,40 @@ class BlockSparseMoE(nn.Module):
         )
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
-        self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
-
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.wv1 = torch.cat([w1, v1], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )

         self.process_group = weights.process_group

-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[: block_rows + 1]
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
# Round the token counts up to the block size used in
# the matrix muliplications. Caculate the starting
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]
# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]
return indices, bin_ids, bins, padded_bins, tokens_per_expert
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(
gate_logits, self.top_k, self.moe_normalize_expert_weights
)
(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and v1,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.v1.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
# Permute back and remove padding
# (sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-        return x.view(*input_shape)
+            x,
+            self.wv1,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=self.moe_normalize_expert_weights,
+            inplace=True,
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
weights,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
) )
# Mask not selected experts
weights.scatter_(1, not_selected_experts, 0)
# Re-normalize
if self.moe_normalize_expert_weights:
weights = weights / torch.norm(
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
)
weights = weights.to(x.dtype)
# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)
# Sum experts
out = out.sum(0)
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        return out
+        return out.view(*x.shape)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
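Both the block-sparse megablocks path and the dense fallback are replaced by a single call to vLLM's `fused_moe` kernel, fed the merged weights built in `__init__` (`wv1 = [w1; v1]`, `w2` stored transposed). As a rough, unfused reference for what that call computes, here is a sketch written for this document; it is an illustration under those layout assumptions, not the kernel itself:

```python
import torch
import torch.nn.functional as F

def naive_top_k_moe(x, wv1, w2, router_logits, top_k, renormalize=True):
    # x: (T, H), wv1: (E, 2F, H), w2: (E, H, F), router_logits: (T, E)
    ffn = wv1.shape[1] // 2
    probs = router_logits.softmax(dim=-1, dtype=torch.float)
    weights, experts = probs.topk(top_k, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(weights[t], experts[t]):
            gate_up = x[t] @ wv1[e].t()                # (2F,)
            h = F.silu(gate_up[:ffn]) * gate_up[ffn:]  # SwiGLU: silu(w1 x) * (v1 x)
            out[t] += w.to(x.dtype) * (h @ w2[e].t())  # back to (H,)
    return out
```

The real kernel fuses the routing, the token gather/scatter, and both GEMMs, which is why only the merged weight layout matters at this level.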
 class DenseMoE(nn.Module):

View File

@ -281,9 +281,8 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention( self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -337,27 +336,30 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module): class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashLlamaLayer( FlashLlamaLayer(
layer_id, prefix=(
config, f"model.layers.{layer_id}"
weights, if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
) )
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps prefix="model.norm" if not prefix else f"{prefix}.model.norm",
weights=weights,
eps=config.rms_norm_eps,
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -368,7 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -376,8 +378,10 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
@@ -406,13 +410,19 @@ class FlashLlamaModel(torch.nn.Module):

 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()

-        self.model = FlashLlamaModel(config, weights)
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
+        )
+        self.model = FlashLlamaModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
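The new `prefix` argument only changes which weight names get looked up, so the same Llama code can load either a standalone checkpoint or one embedded inside another model (the Llava-NeXT code later in this diff passes `language_model`). A tiny illustration of the naming rule used above:

```python
def embed_tokens_weight_name(prefix: str) -> str:
    return "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"

assert embed_tokens_weight_name("") == "model.embed_tokens"
assert embed_tokens_weight_name("language_model") == "language_model.model.embed_tokens"
```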
@ -426,10 +436,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
input_ids, inputs_embeds,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
@ -437,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -285,9 +285,8 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module): class MistralLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = MistralAttention( self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -343,27 +342,24 @@ class MistralLayer(nn.Module):
class MistralModel(torch.nn.Module): class MistralModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
MistralLayer( MistralLayer(
layer_id, prefix=f"{prefix}.layers.{layer_id}",
config, config=config,
weights, weights=weights,
) )
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -374,7 +370,7 @@ class MistralModel(torch.nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -384,9 +380,8 @@ class MistralModel(torch.nn.Module):
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ):
hidden_states = self.embed_tokens(input_ids) hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -410,18 +405,27 @@ class MistralModel(torch.nn.Module):
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class FlashMistralForCausalLM(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.model = MistralModel(config, weights) self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model",
config=config,
weights=weights,
)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window
@ -453,8 +457,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
# kernel requires the true values # kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
input_ids, inputs_embeds,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,

View File

@ -24,6 +24,7 @@ import torch.distributed
import numpy as np import numpy as np
from torch import nn from torch import nn
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
get_linear, get_linear,
) )
HAS_MEGABLOCKS = True
try:
import stk
import megablocks.ops as ops
except ImportError:
logger.warning("Mixtral: megablocks is not installed")
HAS_MEGABLOCKS = False
class MixtralConfig(PretrainedConfig): class MixtralConfig(PretrainedConfig):
model_type = "mixtral" model_type = "mixtral"
@ -321,18 +314,6 @@ def round_up(x: torch.Tensor, value: int):
class BlockSparseMoE(nn.Module): class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accomodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""
def __init__(self, prefix, config: MixtralConfig, weights): def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__() super().__init__()
self.hidden_dim = config.hidden_size self.hidden_dim = config.hidden_size
@ -357,236 +338,40 @@ class BlockSparseMoE(nn.Module):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
-
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.w13 = torch.cat([w1, w3], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts", "w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )

         self.process_group = weights.process_group

-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[: block_rows + 1]
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
# Round the token counts up to the block size used in
# the matrix muliplications. Caculate the starting
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]
# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]
return indices, bin_ids, bins, padded_bins, tokens_per_expert
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(gate_logits, self.top_k)
(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
# Permute back and remove padding
# (sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-        return x.view(*input_shape)
+            x,
+            self.w13,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
) )
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)
# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
weights = weights.to(x.dtype)
# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)
# Sum experts
out = out.sum(0)
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        return out
+        return out.view(*x.shape)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
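Mixtral gets the same treatment as DBRX: `w1` and `w3` are concatenated into `w13`, `w2` is stored transposed, and a single `fused_moe` call with `renormalize=True` and the model's `top_k` replaces both removed paths. A small shape sketch (sizes shrunk for illustration; real Mixtral-8x7B style checkpoints use 8 experts with much larger ffn and hidden dimensions):

```python
import torch

E, F, H = 8, 32, 16                        # toy sizes, see note above
w1 = torch.randn(E, F, H)
w3 = torch.randn(E, F, H)
w13 = torch.cat([w1, w3], dim=1)           # (E, 2F, H): first F rows gate, last F rows up
w2 = torch.randn(E, F, H).transpose(1, 2)  # (E, H, F): weight of the down projection
assert w13.shape == (E, 2 * F, H) and w2.shape == (E, H, F)
```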
 class DenseMoE(nn.Module):
@ -679,9 +464,9 @@ class DenseMoE(nn.Module):
class MixtralLayer(nn.Module): class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = MixtralAttention( self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
@ -740,16 +525,20 @@ class MixtralLayer(nn.Module):
class MixtralModel(torch.nn.Module): class MixtralModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
),
weights=weights,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
MixtralLayer( MixtralLayer(
"model" if not prefix else f"{prefix}.model",
layer_id, layer_id,
config, config,
weights, weights,
@ -758,7 +547,9 @@ class MixtralModel(torch.nn.Module):
] ]
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps prefix="model.norm" if not prefix else f"{prefix}.model.norm",
weights=weights,
eps=config.rms_norm_eps,
) )
self.head_size = self.layers[0].self_attn.head_size self.head_size = self.layers[0].self_attn.head_size
@ -808,13 +599,13 @@ class MixtralModel(torch.nn.Module):
class FlashMixtralForCausalLM(torch.nn.Module): class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.model = MixtralModel(config, weights) self.model = MixtralModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window

View File

@ -0,0 +1,302 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
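Illustrative call of the helper above (the pinpoint list and image size are example values chosen for this note; in the forward pass further down, the function is actually given `config.vision_config.image_size`, i.e. the 336 px crop size, as `patch_size`):

```python
# llava-v1.6 style pinpoints, listed here only as an example.
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
    image_size=(640, 480),
    grid_pinpoints=grid_pinpoints,
    patch_size=336,
)
# Each value counts how many 336 px tiles fit along that axis of the
# resolution chosen by select_best_resolution for this image.
```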
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
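For example (the numbers are mine): a feature map that was letterboxed into a square gets its vertical padding stripped again, following the arithmetic in the function above.

```python
import torch

features = torch.randn(5, 24, 24)  # (channels, height, width) after resize + pad
unpadded = unpad_image(features, original_size=(480, 640))  # original (height, width)
print(unpadded.shape)  # torch.Size([5, 18, 24]): 3 rows of padding removed top and bottom
```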
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
        # Instead of running the full vision tower and selecting hidden_states[-2],
        # only instantiate the layers up to the vision feature layer and skip pooling.
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.use_medusa = config.use_medusa
self.language_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
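A toy illustration of the in-place merge above (shapes and the placeholder token id are made up): every position holding the image-token id is overwritten by one row of the projected image features, while text positions are left alone.

```python
import torch

image_token_index = 32000                 # hypothetical placeholder id
input_ids = torch.tensor([1, 32000, 32000, 15043])
inputs_embeds = torch.zeros(4, 8)
image_features = torch.ones(2, 8)         # one row per image token in the prompt

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

assert inputs_embeds[1:3].sum() == 16     # the two image positions were overwritten
assert inputs_embeds[0].sum() == 0        # text positions are untouched
```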
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -106,6 +106,19 @@ class FlashCausalLMBatch(Batch):
max_tokens=self.blocks * BLOCK_SIZE, max_tokens=self.blocks * BLOCK_SIZE,
) )
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer):
batch_inputs = []
max_truncation = 0
for r in requests:
batch_inputs.append(r.inputs)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
return batch_tokenized_inputs
@classmethod @classmethod
def from_pb( def from_pb(
cls, cls,
@ -114,16 +127,7 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
batch_inputs = [] batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
position_ids = [] position_ids = []
speculative_ids = [] speculative_ids = []
cu_seqlen_prefill = [0] cu_seqlen_prefill = [0]
@@ -165,6 +169,11 @@
             requests_idx_mapping[r.id] = i

             tokenized_input = tokenized_input[-r.truncate :]
+            if (
+                tokenized_input[0] == tokenizer.bos_token_id
+                and tokenized_input[1] == tokenizer.bos_token_id
+            ):
+                tokenized_input = tokenized_input[1:]

             input_length = len(tokenized_input)
             input_lengths.append(input_length)
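The new check guards against prompts that already contain the BOS text (for example when a chat template prepends it) being given a second BOS by the tokenizer. A sketch of the case being handled, with made-up token ids:

```python
bos_token_id = 1
tokenized_input = [1, 1, 15043, 3186]  # "<s><s> Hello world" after the double BOS
if (
    tokenized_input[0] == bos_token_id
    and tokenized_input[1] == bos_token_id
):
    tokenized_input = tokenized_input[1:]
assert tokenized_input == [1, 15043, 3186]
```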
@@ -690,7 +699,7 @@
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
@@ -805,7 +814,7 @@
             for bs in CUDA_GRAPHS:
                 if self.speculate is None or self.speculate + 1 <= bs:
                     self.cuda_graph_warmup(bs, max_s, max_bt)
-        except Exception:
+        except torch.cuda.OutOfMemoryError:
             logger.exception(f"Decode cuda graph warmup failed")
         return int(num_blocks * BLOCK_SIZE)
@@ -865,22 +874,14 @@
         lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if (
-            cu_seqlen_prefill is not None
-            or cuda_graph is None
-            or batch.speculative_ids is not None
-        ):
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
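Instead of hard-coding the padded batch sizes, the decode path now picks the smallest captured CUDA graph whose batch size is at least the current one, and falls back to a regular forward when none exists or during prefill. A small sketch of the lookup, with a made-up set of captured sizes:

```python
cuda_graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}

def pick_graph(bs: int):
    candidates = sorted(k for k in cuda_graphs.keys() if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None

assert pick_graph(3) == "graph_bs4"  # padded up to the next captured size
assert pick_graph(8) == "graph_bs8"
assert pick_graph(9) is None         # too large: fall back to model.forward
```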

View File

@ -3,12 +3,11 @@ import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from typing import Optional from typing import Optional
from transformers import AutoTokenizer from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM, FlashCohereForCausalLM,
CohereConfig,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
@ -32,7 +31,7 @@ class FlashCohere(FlashCausalLM):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashCohere is only available on GPU") raise NotImplementedError("FlashCohere is only available on GPU")
@ -46,7 +45,7 @@ class FlashCohere(FlashCausalLM):
from_slow=False, from_slow=False,
) )
config = CohereConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize

View File

@@ -67,7 +67,8 @@ class FlashLlama(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = FlashLlamaForCausalLM(config, weights)
+        prefix = ""
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,

View File

@@ -6,8 +6,7 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, AutoTokenizer
-from transformers.models.llama import LlamaTokenizerFast
+from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Type
 from text_generation_server.pb import generate_pb2
@@ -65,19 +64,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window, sliding_window_blocks = get_sliding_windows()
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
         position_ids = []
         cu_seqlen_prefill = [0]
         needed_blocks_slots = []
@@ -301,14 +302,15 @@ class FlashMistralBatch(FlashCausalLMBatch):
 class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
-        config_cls,
         model_cls,
         model_id: str,
+        config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -317,16 +319,7 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
-        try:
-            tokenizer = LlamaTokenizerFast.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
-        except Exception:
-            tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = tokenizer_class.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -341,10 +334,12 @@ class BaseFlashMistral(FlashCausalLM):
         config.use_medusa = use_medusa
         # Set context windows
-        if config.sliding_window is not None:
+        if getattr(config, "sliding_window", None) is not None:
             set_sliding_window(
                 config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
             )
+        else:
+            config.sliding_window = None
         torch.distributed.barrier(group=self.process_group)
@@ -353,17 +348,19 @@ class BaseFlashMistral(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)
-        model = model_cls(config, weights)
+        prefix = ""
+        model = model_cls(prefix, config, weights)
         self.cuda_graphs = {}
         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
+            num_layers=num_layers,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
             dtype=dtype,
             device=device,
             rank=rank,
@@ -371,6 +368,16 @@ class BaseFlashMistral(FlashCausalLM):
             sliding_window=config.sliding_window,
         )
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.model.layers),
+            model.model.num_key_value_heads,
+            model.model.head_size,
+        )
+    def max_past(self) -> int:
+        return self.model.max_past
     @property
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
@@ -378,7 +385,7 @@ class BaseFlashMistral(FlashCausalLM):
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
@@ -485,11 +492,11 @@ class BaseFlashMistral(FlashCausalLM):
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
-        if cu_seqlen_prefill is None and self.model.max_past is not None:
+        if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.model.max_past, max_s)
+            max_s = min(self.max_past(), max_s)
         bs = input_ids.shape[0]
         padded_bs = bs
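The sliding-window branch above converts the model's attention window into a number of KV-cache blocks. A quick standalone sketch of that arithmetic; BLOCK_SIZE is assumed here to be 16, the paged-attention block size TGI uses elsewhere, so treat that value as an assumption:

import math

BLOCK_SIZE = 16          # assumed paged-attention block size
sliding_window = 4096    # e.g. a Mistral-style attention window
print(math.ceil(sliding_window / BLOCK_SIZE))   # -> 256 blocks cover one full window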


@@ -1,4 +1,5 @@
 import torch
+import torch
 import time
 from dataclasses import dataclass
@@ -20,29 +21,13 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models.vlm_causal_lm import split
 import re
 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-def split(string):
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append(string[cursor:start])
-        parts.append(pattern.group(1))
-        cursor = pattern.end()
-    if cursor != len(string):
-        parts.append(string[cursor:])
-    return parts
 tracer = trace.get_tracer(__name__)
@@ -93,10 +78,21 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "IdeficsCausalLMBatch":
+        raise NotImplementedError
+
+    @classmethod
+    def from_pb_processor(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         processor: ProcessorMixin,  # Hack
+        config,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "IdeficsCausalLMBatch":
@@ -127,10 +123,14 @@ class IdeficsCausalLMBatch(Batch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
+        # TODO Check impact on idefics
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            prompts.append(split(inp))
+            prompt = []
+            for chunk in split(inp):
+                prompt.append(chunk["content"])
+            prompts.append(prompt)
         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
@@ -141,7 +141,8 @@ class IdeficsCausalLMBatch(Batch):
             padding=True,
             truncation=True,
             max_length=max_truncation,
-            add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
+            # TODO Check impact on idefics
+            # add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
@@ -156,7 +157,7 @@ class IdeficsCausalLMBatch(Batch):
         max_input_length = input_lengths.max()
         input_ids = tokenized_inputs["input_ids"]
-        pixel_values = tokenized_inputs["pixel_values"]
+        pixel_values = tokenized_inputs.get("pixel_values", None)
         image_hidden_states = None
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
@@ -165,11 +166,14 @@ class IdeficsCausalLMBatch(Batch):
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
         # Do the same for image_attention_mask
+        if pixel_values is None:
+            image_attention_mask = None
+        else:
             image_attention_mask = input_ids.new_zeros(
                 (
                     pb.size,
                     max_input_length + padding_right_offset,
-                    tokenized_inputs["pixel_values"].size(1),
+                    pixel_values.size(1),
                 )
             )
             image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
@@ -677,6 +681,9 @@ class IdeficsCausalLM(Model):
         start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+        if batch.image_attention_mask is None:
+            image_attention_mask = None
+        else:
             if batch.input_ids.size(1) == 1:
                 # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
                 # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension


@@ -0,0 +1,36 @@
import torch
from typing import Optional
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class LlavaNext(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
super().__init__(
model_cls=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)


@@ -0,0 +1,329 @@
import re
import torch
import math
from PIL import Image
from io import BytesIO
import base64
from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
tracer = trace.get_tracer(__name__)
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
def split(string) -> List[Dict[str, str]]:
parts = []
cursor = 0
for pattern in IMAGES.finditer(string):
start = pattern.start()
if start != cursor:
parts.append({"type": "text", "content": string[cursor:start]})
parts.append({"type": "image", "content": pattern.group(1)})
cursor = pattern.end()
if cursor != len(string):
parts.append({"type": "text", "content": string[cursor:]})
return parts
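As a quick illustration (a minimal sketch that assumes the `split` helper above is in scope), a prompt mixing text with a markdown image reference is broken into typed chunks:

chunks = split("What is shown? ![](data:image/png;base64,AAAA)")
# -> [{"type": "text", "content": "What is shown? "},
#     {"type": "image", "content": "data:image/png;base64,AAAA"}]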
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width), matching what the caller below passes in.
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid as (num_patch_height, num_patch_width).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
image_grid_pinpoints = config.image_grid_pinpoints
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
assert image_size % patch_size == 0
npatches = image_size // patch_size
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
[height, width],
image_grid_pinpoints,
image_size,
)
height_of_patch = math.ceil(height / width * npatches)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
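For intuition, a small sketch of the grid computation (assumes `transformers` is installed and reuses `get_anyres_image_grid_shape` above; the pinpoints below are the usual LLaVA-NeXT defaults, and the third argument mirrors how `get_number_of_features` passes the full `image_size`, so the grid is counted in whole 336-pixel tiles):

grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
# an 889x1024 image is best matched by the 672x672 pinpoint -> a 2x2 grid of tiles
print(get_anyres_image_grid_shape((889, 1024), grid_pinpoints, 336))  # (2, 2)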
def load_data_uri(image_uri: str) -> Image.Image:
image_uri = image_uri.split(",")[-1]
content = base64.b64decode(image_uri)
image = Image.open(BytesIO(content))
return image
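A small round-trip sketch for `load_data_uri` (assumes Pillow is available, which the image processor already requires):

import base64
from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (336, 336), color="white").save(buf, format="PNG")
uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
assert load_data_uri(uri).size == (336, 336)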
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.image_sizes = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.image_sizes = None
return batch
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
elif chunk["type"] == "image":
image = chunk["content"]
# URLs should never reach this point anymore: image fetching is handled on the
# Rust layer, which avoids making n identical requests (one per TP rank).
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
}
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.pixel_values = None
batch.image_sizes = None
return batch
class VlmCausalLM(BaseFlashMistral):
@property
def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward(
self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Copy the block tables for every speculative copy of each sequence
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
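To make the speculative path above easier to follow, here is a standalone toy (hypothetical sizes, not TGI code) of how each of the B sequences is expanded into speculative_length + 1 rows before the single decode pass:

import torch

B, speculative_length = 2, 3
new_length = speculative_length + 1
input_ids = torch.tensor([10, 20])                       # last real token per sequence
speculative_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
# -> tensor([10, 11, 12, 13, 20, 21, 22, 23]): B * new_length rows fed in one pass
position_ids = torch.tensor([5, 7])
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
# -> tensor([5, 6, 7, 8, 7, 8, 9, 10])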


@@ -13,6 +13,7 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@@ -78,13 +79,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         except ImportError:
             pass
-        if (
-            self.model.batch_type == IdeficsCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+        if self.model.batch_type in {
+            IdeficsCausalLMBatch,
+            VlmCausalLMBatch,
+        }:  # Hack, i would rather use kwargs in the `from_pb` call
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model.model.config,
                 self.model.dtype,
                 self.model.device,
             )
@@ -100,13 +103,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Prefill(self, request, context):
         start = time.time_ns()
-        if (
-            self.model.batch_type == IdeficsCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+        if self.model.batch_type in {
+            IdeficsCausalLMBatch,
+            VlmCausalLMBatch,
+        }:  # Hack, i would rather use kwargs in the `from_pb` call
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model.model.config,
                 self.model.dtype,
                 self.model.device,
             )


@@ -88,6 +88,9 @@ def attention(
         out,
         cu_seqlens,
         cu_seqlens,
+        None,
+        None,
+        None,
         max_s,
         max_s,
         0.0,


@@ -19,7 +19,6 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
-from text_generation_server.utils.log import log_once
 HAS_AWQ = True
 try:
@@ -35,12 +34,6 @@ except Exception:
     HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-#     V2 = False
-#     log_once(
-#         logger.warning,
-#         "Disabling exllama v2 and using v1 instead because there are issues when sharding",
-#     )
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
@@ -174,6 +167,8 @@ class EETQLinear(nn.Module):
     ) -> None:
         super().__init__()
         device = weight.device
+        if weight.dtype != torch.float16:
+            weight = weight.to(dtype=torch.float16)
         weight = torch.t(weight).contiguous().cpu()
         weight, scale = quant_weights(weight, torch.int8, False)
@@ -187,6 +182,48 @@ class EETQLinear(nn.Module):
         return output
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
device = weight.device
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
return qweight, scale
class Fp8Linear(nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.dtype = weight.dtype
self.qweight, self.scale = fp8_quantize(weight)
self.bias = bias if bias is not None else None
def forward(self, input: torch.Tensor) -> torch.Tensor:
qinput, scale = fp8_quantize(input)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
)
return output
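As a sanity-check sketch of the scaling scheme used by fp8_quantize above (assumes a torch build that ships torch.float8_e4m3fn, i.e. torch >= 2.1; Fp8Linear itself additionally needs a GPU supported by torch._scaled_mm):

import torch

w = torch.randn(16, 32, dtype=torch.float16)
qweight, inv_scale = fp8_quantize(w)            # inv_scale undoes the scaling
w_approx = qweight.to(torch.float16) * inv_scale
print((w - w_approx).abs().max())               # small fp8 quantization error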
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
@@ -298,6 +335,8 @@ def get_linear(weight, bias, quantize):
             raise ImportError(
                 "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
             )
+    elif quantize == "fp8":
+        linear = Fp8Linear(weight, bias)
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(
@@ -393,12 +432,12 @@ class ResBlock(torch.nn.Module):
 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )
@@ -408,12 +447,12 @@ class MedusaModel(torch.nn.Module):
 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
@@ -428,7 +467,7 @@ class MedusaHead(torch.nn.Module):
         return x
-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
         self.lm_head = lm_head
@@ -436,38 +475,156 @@ class SpeculativeHead(nn.Module):
     @staticmethod
     def load(config, prefix: str, weights):
-        lm_head = TensorParallelHead.load(config, prefix, weights)
-        use_medusa = config.use_medusa
-        if use_medusa:
         from pathlib import Path
         from safetensors import safe_open
         import json
+        use_medusa = config.use_medusa
         medusa_config = str(Path(use_medusa) / "config.json")
         filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
         with open(medusa_config, "r") as f:
-            config = json.load(f)
+            medusa_config = json.load(f)
         routing = weights.routing
         with safe_open(filename, framework="pytorch") as f:
             for k in f.keys():
-                if k in routing:
+                if k in routing and routing[k] != filename:
                     raise RuntimeError(
                         f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                     )
-                weights.routing[k] = filename
+                routing[k] = filename
-        medusa = MedusaModel(config, weights)
+        medusa = MedusaModel(config, medusa_config, weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MedusaHeadV1(lm_head, medusa)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
# If we have too many tokens, we skip speculative logits
if input.shape[0] > 128:
return logits, None
speculative_logits = self.medusa(input)
return logits, speculative_logits
class MedusaHeadV2(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
from pathlib import Path
from safetensors import safe_open
import json
use_medusa = config.use_medusa
medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
self.n_medusa_heads = medusa_config["medusa_num_heads"]
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
dim=0,
weights=weights,
bias=True,
)
self.process_group = weights.process_group
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
self.act = torch.nn.SiLU()
self.lm_head = TensorParallelHead.load(config, prefix, weights)
def forward(self, x):
# If we have too many tokens, we skip speculative logits
if x.shape[0] > 128:
logits = self.lm_head(x)
return logits, None
size = x.shape[-1]
block_size = (size + self.world_size - 1) // self.world_size
start = self.rank * block_size
stop = (self.rank + 1) * block_size
x_block = x[:, start:stop]
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
medusa_res = self.act(self.linear(x)).reshape(
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
)
# Apply all residual medusa heads
output = x[:, start:stop].unsqueeze(-2) + medusa_res
# Gather medusa heads
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
# Stack x and medusa residual x
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
# Compute lm head on x + medusa residual x
logits = self.lm_head(stacked_x)
# Finally, split logits from speculative logits
logits, speculative_logits = torch.split(
logits, [1, self.n_medusa_heads], dim=-2
)
# Squeeze added dimension
logits = logits.squeeze(-2)
return logits, speculative_logits
class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
use_medusa = config.use_medusa
if use_medusa:
lm_head = None
try:
medusa = MedusaHeadV1.load(config, prefix, weights)
except:
medusa = MedusaHeadV2(config, prefix, weights)
         else:
+            lm_head = TensorParallelHead.load(config, prefix, weights)
             medusa = None
         return SpeculativeHead(lm_head, medusa)
     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        logits = self.lm_head(input)
-        speculative_logits = self.medusa(input) if self.medusa is not None else None
-        return logits, speculative_logits
+        if self.medusa is not None:
+            return self.medusa(input)
+        assert self.head is not None
+        logits = self.head(input)
+        return logits, None
 class TensorParallelHead(SuperLayer):
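For intuition about the shapes involved, here is a standalone toy (hypothetical sizes, not TGI code) of the final split that MedusaHeadV2.forward performs above to separate regular logits from per-head speculative logits:

import torch

num_tokens, n_medusa_heads, vocab = 4, 3, 11
stacked_logits = torch.randn(num_tokens, 1 + n_medusa_heads, vocab)
logits, speculative_logits = torch.split(stacked_logits, [1, n_medusa_heads], dim=-2)
logits = logits.squeeze(-2)
print(logits.shape, speculative_logits.shape)   # torch.Size([4, 11]) torch.Size([4, 3, 11])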


@@ -1,8 +1,6 @@
 import torch
-# vllm imports
-from vllm import cache_ops
-from vllm import attention_ops
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
 _PARTITION_SIZE = 512
@@ -14,7 +12,18 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
+    if IS_CUDA_SYSTEM:
+        from vllm._C import cache_ops
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif IS_ROCM_SYSTEM:
+        from vllm import cache_ops
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    else:
+        raise ValueError("vllm is not supported on your system")
 def attention(
@@ -54,8 +63,29 @@ def attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
+        if IS_CUDA_SYSTEM:
+            from vllm._C import ops
+            ops.paged_attention_v1(
+                out,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        elif IS_ROCM_SYSTEM:
+            from vllm import attention_ops
             attention_ops.paged_attention_v1(
                 out,
                 query,
@@ -69,6 +99,9 @@ def attention(
                 max_s,
                 None,
             )
+        else:
+            raise ValueError("vllm is not supported on your system")
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
@@ -83,6 +116,31 @@ def attention(
             device=out.device,
         )
         max_logits = torch.empty_like(exp_sums)
+        if IS_CUDA_SYSTEM:
+            from vllm._C import ops
+            ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        elif IS_ROCM_SYSTEM:
+            from vllm import attention_ops
             attention_ops.paged_attention_v2(
                 out,
                 exp_sums,
@@ -99,3 +157,5 @@ def attention(
                 max_s,
                 None,
             )
+        else:
+            raise ValueError("vllm is not supported on your system")
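A standalone sketch of the V1/V2 dispatch heuristic above (pure arithmetic, no vllm needed; the partition count is presumably computed as ceil(max_s / _PARTITION_SIZE) earlier in this function, so treat that formula as an assumption):

_PARTITION_SIZE = 512
max_s, num_seqs, num_heads = 4096, 2, 32
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE    # -> 8
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
print(max_num_partitions, use_v1)   # 8 False -> the paged_attention_v2 path is taken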

tgi-entrypoint.sh Executable file

@@ -0,0 +1,5 @@
#!/bin/bash
ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
text-generation-launcher $@