mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Merge branch 'huggingface:main' into main
This commit is contained in:
commit
3116fb5113
5
.github/workflows/autodocs.yml
vendored
5
.github/workflows/autodocs.yml
vendored
@ -13,7 +13,10 @@ jobs:
|
||||
|
||||
- name: Install Launcher
|
||||
id: install-launcher
|
||||
run: cargo install --git https://github.com/${{ github.repository }} --branch ${{ github.head_ref }} text-generation-launcher
|
||||
env:
|
||||
REF: ${{ github.head_ref }}
|
||||
REPO: ${{ github.repository }}
|
||||
run: cargo install --git "https://github.com/$REPO" --branch "$REF" text-generation-launcher
|
||||
|
||||
- name: Check launcher Docs are up-to-date
|
||||
run: |
|
||||
|
913
Cargo.lock
generated
913
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
10
Cargo.toml
10
Cargo.toml
@ -9,13 +9,19 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "1.4.5"
|
||||
version = "2.0.1"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
||||
[workspace.dependencies]
|
||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
incremental = true
|
||||
lto = "off"
|
||||
lto = "fat"
|
||||
opt-level = 3
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
|
17
Dockerfile
17
Dockerfile
@ -85,7 +85,7 @@ FROM pytorch-install as kernel-builder
|
||||
ARG MAX_JOBS=8
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
ninja-build \
|
||||
ninja-build cmake \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Build Flash Attention CUDA kernels
|
||||
@ -160,11 +160,6 @@ WORKDIR /usr/src
|
||||
COPY server/Makefile-selective-scan Makefile
|
||||
RUN make build-all
|
||||
|
||||
# Build megablocks
|
||||
FROM kernel-builder as megablocks-builder
|
||||
|
||||
RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
|
||||
|
||||
@ -186,8 +181,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy conda with PyTorch and Megablocks installed
|
||||
COPY --from=megablocks-builder /opt/conda /opt/conda
|
||||
# Copy conda with PyTorch installed
|
||||
COPY --from=pytorch-install /opt/conda /opt/conda
|
||||
|
||||
# Copy build artifacts from flash attention builder
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
@ -215,7 +210,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Install flash-attention dependencies
|
||||
# Install vllm/flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
||||
# Install server
|
||||
@ -250,5 +245,7 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
# Final image
|
||||
FROM base
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
CMD ["--json-output"]
|
||||
|
@ -76,7 +76,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
|
||||
model=HuggingFaceH4/zephyr-7b-beta
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
|
||||
```
|
||||
|
||||
And then you can make requests like
|
||||
@ -90,7 +90,7 @@ curl 127.0.0.1:8080/generate_stream \
|
||||
|
||||
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above.
|
||||
|
||||
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
||||
```
|
||||
@ -120,7 +120,7 @@ model=meta-llama/Llama-2-7b-chat-hf
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
token=<your cli READ token>
|
||||
|
||||
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
|
||||
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
|
||||
```
|
||||
|
||||
### A note on Shared Memory (shm)
|
||||
|
@ -23,9 +23,9 @@ serde_json = "1.0"
|
||||
tabled = "0.14.0"
|
||||
text-generation-client = { path = "../router/client" }
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { version = "0.14.0", features = ["http"] }
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
hf-hub = "0.3.1"
|
||||
hf-hub = { workspace = true }
|
||||
|
@ -9,6 +9,11 @@ def flan_t5_xxl():
|
||||
return "google/flan-t5-xxl"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_7b():
|
||||
return "meta-llama/Llama-2-7b-chat-hf"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_model():
|
||||
return "fake/model"
|
||||
@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl):
|
||||
return f"{base_url}/{flan_t5_xxl}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_7b_url(base_url, llama_7b):
|
||||
return f"{base_url}/{llama_7b}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_url(base_url, fake_model):
|
||||
return f"{base_url}/{fake_model}"
|
||||
|
@ -5,24 +5,24 @@ from text_generation.errors import NotFoundError, ValidationError
|
||||
from text_generation.types import FinishReason, InputToken
|
||||
|
||||
|
||||
def test_generate(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
def test_generate(llama_7b_url, hf_headers):
|
||||
client = Client(llama_7b_url, hf_headers)
|
||||
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
|
||||
|
||||
assert response.generated_text == ""
|
||||
assert response.generated_text == "_"
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
assert len(response.details.prefill) == 1
|
||||
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
|
||||
assert len(response.details.prefill) == 2
|
||||
assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
|
||||
assert len(response.details.tokens) == 1
|
||||
assert response.details.tokens[0].id == 3
|
||||
assert response.details.tokens[0].text == " "
|
||||
assert response.details.tokens[0].id == 29918
|
||||
assert response.details.tokens[0].text == "_"
|
||||
assert not response.details.tokens[0].special
|
||||
|
||||
|
||||
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
def test_generate_best_of(llama_7b_url, hf_headers):
|
||||
client = Client(llama_7b_url, hf_headers)
|
||||
response = client.generate(
|
||||
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
|
||||
)
|
||||
@ -39,14 +39,14 @@ def test_generate_not_found(fake_url, hf_headers):
|
||||
client.generate("test")
|
||||
|
||||
|
||||
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
def test_generate_validation_error(llama_7b_url, hf_headers):
|
||||
client = Client(llama_7b_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
client.generate("test", max_new_tokens=10_000)
|
||||
|
||||
|
||||
def test_generate_stream(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
def test_generate_stream(llama_7b_url, hf_headers):
|
||||
client = Client(llama_7b_url, hf_headers)
|
||||
responses = [
|
||||
response for response in client.generate_stream("test", max_new_tokens=1)
|
||||
]
|
||||
@ -54,7 +54,7 @@ def test_generate_stream(flan_t5_xxl_url, hf_headers):
|
||||
assert len(responses) == 1
|
||||
response = responses[0]
|
||||
|
||||
assert response.generated_text == ""
|
||||
assert response.generated_text == "_"
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
|
||||
list(client.generate_stream("test"))
|
||||
|
||||
|
||||
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
def test_generate_stream_validation_error(llama_7b_url, hf_headers):
|
||||
client = Client(llama_7b_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
list(client.generate_stream("test", max_new_tokens=10_000))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
async def test_generate_async(llama_7b_url, hf_headers):
|
||||
client = AsyncClient(llama_7b_url, hf_headers)
|
||||
response = await client.generate(
|
||||
"test", max_new_tokens=1, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.generated_text == ""
|
||||
assert response.generated_text == "_"
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
assert len(response.details.prefill) == 1
|
||||
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
|
||||
assert len(response.details.prefill) == 2
|
||||
assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
|
||||
assert response.details.prefill[1] == InputToken(
|
||||
id=1243, text="test", logprob=-10.96875
|
||||
)
|
||||
assert len(response.details.tokens) == 1
|
||||
assert response.details.tokens[0].id == 3
|
||||
assert response.details.tokens[0].text == " "
|
||||
assert response.details.tokens[0].id == 29918
|
||||
assert response.details.tokens[0].text == "_"
|
||||
assert not response.details.tokens[0].special
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
async def test_generate_async_best_of(llama_7b_url, hf_headers):
|
||||
client = AsyncClient(llama_7b_url, hf_headers)
|
||||
response = await client.generate(
|
||||
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
|
||||
)
|
||||
@ -112,15 +115,15 @@ async def test_generate_async_not_found(fake_url, hf_headers):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
async def test_generate_async_validation_error(llama_7b_url, hf_headers):
|
||||
client = AsyncClient(llama_7b_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
await client.generate("test", max_new_tokens=10_000)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
async def test_generate_stream_async(llama_7b_url, hf_headers):
|
||||
client = AsyncClient(llama_7b_url, hf_headers)
|
||||
responses = [
|
||||
response async for response in client.generate_stream("test", max_new_tokens=1)
|
||||
]
|
||||
@ -128,7 +131,7 @@ async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
|
||||
assert len(responses) == 1
|
||||
response = responses[0]
|
||||
|
||||
assert response.generated_text == ""
|
||||
assert response.generated_text == "_"
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
|
||||
client = AsyncClient(llama_7b_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
||||
pass
|
||||
|
@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
|
||||
usage: Optional[Any] = None
|
||||
|
||||
|
||||
class CompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
text: str
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
|
||||
usage: Any
|
||||
|
||||
|
||||
class Completion(BaseModel):
|
||||
# Completion details
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[CompletionComplete]
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
# Model identifier
|
||||
model: str
|
||||
|
@ -10,7 +10,7 @@
|
||||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "1.4.5"
|
||||
"version": "2.0.1"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
@ -408,9 +408,14 @@
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Generated Text",
|
||||
"description": "Generated Chat Completion",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ChatCompletion"
|
||||
}
|
||||
},
|
||||
"text/event-stream": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ChatCompletionChunk"
|
||||
}
|
||||
@ -492,11 +497,16 @@
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Generated Text",
|
||||
"description": "Generated Chat Completion",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ChatCompletionChunk"
|
||||
"$ref": "#/components/schemas/Completion"
|
||||
}
|
||||
},
|
||||
"text/event-stream": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CompletionCompleteChunk"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -930,7 +940,7 @@
|
||||
"tool_prompt": {
|
||||
"type": "string",
|
||||
"description": "A prompt to be appended before the tools",
|
||||
"example": "\"Based on the conversation, please choose the most appropriate tool to use: \"",
|
||||
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
|
||||
"nullable": true
|
||||
},
|
||||
"tools": {
|
||||
@ -1071,7 +1081,10 @@
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "The prompt to generate completions for.",
|
||||
"example": "What is Deep Learning?"
|
||||
},
|
||||
@ -1234,17 +1247,17 @@
|
||||
"type": "object",
|
||||
"required": [
|
||||
"name",
|
||||
"parameters"
|
||||
"arguments"
|
||||
],
|
||||
"properties": {
|
||||
"arguments": {},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"parameters": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
"GenerateParameters": {
|
||||
@ -1260,7 +1273,7 @@
|
||||
},
|
||||
"decoder_input_details": {
|
||||
"type": "boolean",
|
||||
"default": "true"
|
||||
"default": "false"
|
||||
},
|
||||
"details": {
|
||||
"type": "boolean",
|
||||
@ -1285,6 +1298,7 @@
|
||||
"$ref": "#/components/schemas/GrammarType"
|
||||
}
|
||||
],
|
||||
"default": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"max_new_tokens": {
|
||||
@ -1478,6 +1492,7 @@
|
||||
"max_batch_total_tokens",
|
||||
"max_waiting_tokens",
|
||||
"validation_workers",
|
||||
"max_client_batch_size",
|
||||
"version"
|
||||
],
|
||||
"properties": {
|
||||
@ -1503,6 +1518,11 @@
|
||||
"example": "2",
|
||||
"minimum": 0
|
||||
},
|
||||
"max_client_batch_size": {
|
||||
"type": "integer",
|
||||
"example": "32",
|
||||
"minimum": 0
|
||||
},
|
||||
"max_concurrent_requests": {
|
||||
"type": "integer",
|
||||
"description": "Router Parameters",
|
||||
|
@ -60,12 +60,13 @@ Options:
|
||||
[env: QUANTIZE=]
|
||||
|
||||
Possible values:
|
||||
- awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
|
||||
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
|
||||
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
|
||||
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
|
||||
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
|
||||
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
|
||||
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
|
||||
|
||||
```
|
||||
## SPECULATE
|
||||
@ -128,23 +129,29 @@ Options:
|
||||
[env: MAX_TOP_N_TOKENS=]
|
||||
[default: 5]
|
||||
|
||||
```
|
||||
## MAX_INPUT_TOKENS
|
||||
```shell
|
||||
--max-input-tokens <MAX_INPUT_TOKENS>
|
||||
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
|
||||
|
||||
[env: MAX_INPUT_TOKENS=]
|
||||
|
||||
```
|
||||
## MAX_INPUT_LENGTH
|
||||
```shell
|
||||
--max-input-length <MAX_INPUT_LENGTH>
|
||||
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle
|
||||
Legacy version of [`Args::max_input_tokens`]
|
||||
|
||||
[env: MAX_INPUT_LENGTH=]
|
||||
[default: 1024]
|
||||
|
||||
```
|
||||
## MAX_TOTAL_TOKENS
|
||||
```shell
|
||||
--max-total-tokens <MAX_TOTAL_TOKENS>
|
||||
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be
|
||||
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
|
||||
|
||||
[env: MAX_TOTAL_TOKENS=]
|
||||
[default: 2048]
|
||||
|
||||
```
|
||||
## WAITING_SERVED_RATIO
|
||||
@ -161,10 +168,9 @@ Options:
|
||||
## MAX_BATCH_PREFILL_TOKENS
|
||||
```shell
|
||||
--max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
|
||||
Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
|
||||
Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room
|
||||
|
||||
[env: MAX_BATCH_PREFILL_TOKENS=]
|
||||
[default: 4096]
|
||||
|
||||
```
|
||||
## MAX_BATCH_TOTAL_TOKENS
|
||||
@ -209,10 +215,9 @@ Options:
|
||||
## CUDA_GRAPHS
|
||||
```shell
|
||||
--cuda-graphs <CUDA_GRAPHS>
|
||||
Specify the batch sizes to compute cuda graphs for. Use "0" to disable
|
||||
Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Default = "1,2,4,8,16,32"
|
||||
|
||||
[env: CUDA_GRAPHS=]
|
||||
[default: 1,2,4,8,16,32,64,96,128]
|
||||
|
||||
```
|
||||
## HOSTNAME
|
||||
@ -393,6 +398,15 @@ Options:
|
||||
-e, --env
|
||||
Display a lot of information about your runtime environment
|
||||
|
||||
```
|
||||
## MAX_CLIENT_BATCH_SIZE
|
||||
```shell
|
||||
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
|
||||
Control the maximum number of inputs that a client can send in a single request
|
||||
|
||||
[env: MAX_CLIENT_BATCH_SIZE=]
|
||||
[default: 4]
|
||||
|
||||
```
|
||||
## HELP
|
||||
```shell
|
||||
|
@ -74,7 +74,7 @@ curl localhost:3000/generate \
|
||||
|
||||
A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
|
||||
|
||||
> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compliation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
|
||||
> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
|
||||
|
||||
### Constrain with Pydantic
|
||||
|
||||
|
@ -22,6 +22,8 @@ The following models are optimized and can be served with TGI, which uses custom
|
||||
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
|
||||
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- [Phi](https://huggingface.co/microsoft/phi-2)
|
||||
- [Idefics](HuggingFaceM4/idefics-9b-instruct) (Multimodal)
|
||||
- [Llava-next](llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal)
|
||||
|
||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||
|
||||
|
@ -9,6 +9,7 @@ import json
|
||||
import math
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
|
||||
from docker.errors import NotFound
|
||||
from typing import Optional, List, Dict
|
||||
@ -26,6 +27,7 @@ from text_generation.types import (
|
||||
ChatComplete,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionComplete,
|
||||
Completion,
|
||||
)
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
@ -69,17 +71,22 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
data = json.loads(data)
|
||||
if isinstance(data, Dict) and "choices" in data:
|
||||
choices = data["choices"]
|
||||
if (
|
||||
isinstance(choices, List)
|
||||
and len(choices) >= 1
|
||||
and "delta" in choices[0]
|
||||
):
|
||||
return ChatCompletionChunk(**data)
|
||||
if isinstance(choices, List) and len(choices) >= 1:
|
||||
if "delta" in choices[0]:
|
||||
return ChatCompletionChunk(**data)
|
||||
if "text" in choices[0]:
|
||||
return Completion(**data)
|
||||
return ChatComplete(**data)
|
||||
|
||||
if isinstance(data, Dict):
|
||||
return Response(**data)
|
||||
if isinstance(data, List):
|
||||
if (
|
||||
len(data) > 0
|
||||
and "object" in data[0]
|
||||
and data[0]["object"] == "text_completion"
|
||||
):
|
||||
return [Completion(**d) for d in data]
|
||||
return [Response(**d) for d in data]
|
||||
raise NotImplementedError
|
||||
|
||||
@ -161,6 +168,9 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
)
|
||||
)
|
||||
|
||||
def eq_completion(response: Completion, other: Completion) -> bool:
|
||||
return response.choices[0].text == other.choices[0].text
|
||||
|
||||
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
|
||||
return (
|
||||
response.choices[0].message.content == other.choices[0].message.content
|
||||
@ -184,6 +194,11 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
if not isinstance(snapshot_data, List):
|
||||
snapshot_data = [snapshot_data]
|
||||
|
||||
if isinstance(serialized_data[0], Completion):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
||||
if isinstance(serialized_data[0], ChatComplete):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
@ -277,6 +292,8 @@ def launcher(event_loop):
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
@ -314,6 +331,12 @@ def launcher(event_loop):
|
||||
args.append(revision)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
|
||||
@ -347,6 +370,8 @@ def launcher(event_loop):
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
||||
@ -367,6 +392,12 @@ def launcher(event_loop):
|
||||
args.append(revision)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
||||
client = docker.from_env()
|
||||
|
||||
|
@ -13,11 +13,11 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795556,
|
||||
"created": 1712874856,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 60,
|
@ -0,0 +1,38 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " PR for more information?"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "le Business Incubator is providing a workspace"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " severely flawed and often has a substandard"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "hd20220811-"
|
||||
}
|
||||
],
|
||||
"created": 1713284455,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"usage": {
|
||||
"completion_tokens": 36,
|
||||
"prompt_tokens": 8,
|
||||
"total_tokens": 44
|
||||
}
|
||||
}
|
@ -0,0 +1,602 @@
|
||||
[
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "hd"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "aho"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "ima"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " Sarah"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " Yes"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " And"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "i"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": ","
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " what"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "s"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " Moh"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "m"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " Room"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "s"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " the"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " tired"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": ":"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " capital"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " of"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " She"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " scale"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " of"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " being"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
}
|
||||
]
|
@ -0,0 +1,20 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " PR for flake8"
|
||||
}
|
||||
],
|
||||
"created": 1713284454,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"usage": {
|
||||
"completion_tokens": 5,
|
||||
"prompt_tokens": 6,
|
||||
"total_tokens": 11
|
||||
}
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 6,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -10.5,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -12.140625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0654297,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1014,
|
||||
"logprob": -2.7460938,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 6032,
|
||||
"logprob": -1.359375,
|
||||
"special": false,
|
||||
"text": " purpose"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 456,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 1369,
|
||||
"logprob": -0.40063477,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request\nThe purpose of this test"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,73 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.00756073,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.20117188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 16114,
|
||||
"logprob": -1.2597656,
|
||||
"special": false,
|
||||
"text": "Once"
|
||||
},
|
||||
{
|
||||
"id": 3714,
|
||||
"logprob": -0.20825195,
|
||||
"special": false,
|
||||
"text": " upon"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.00178051,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 727,
|
||||
"logprob": -0.011955261,
|
||||
"special": false,
|
||||
"text": " time"
|
||||
},
|
||||
{
|
||||
"id": 28725,
|
||||
"logprob": -0.17541504,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 736,
|
||||
"logprob": -0.91308594,
|
||||
"special": false,
|
||||
"text": " there"
|
||||
},
|
||||
{
|
||||
"id": 403,
|
||||
"logprob": -0.058410645,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.009689331,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nOnce upon a time, there was a"
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 9,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
@ -14,7 +14,7 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 16017,
|
||||
"logprob": -0.30908203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " blue"
|
||||
},
|
||||
@ -26,39 +26,45 @@
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -0.28271484,
|
||||
"logprob": -0.4716797,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 15484,
|
||||
"logprob": -1.7929688,
|
||||
"id": 261,
|
||||
"logprob": -0.044677734,
|
||||
"special": false,
|
||||
"text": "appear"
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 345,
|
||||
"logprob": -0.8935547,
|
||||
"id": 35622,
|
||||
"logprob": -0.79589844,
|
||||
"special": false,
|
||||
"text": "ed"
|
||||
"text": " cloud"
|
||||
},
|
||||
{
|
||||
"id": 281,
|
||||
"id": 263,
|
||||
"logprob": -1.2958984,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 305,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"id": 35622,
|
||||
"logprob": -1.1630859,
|
||||
"special": false,
|
||||
"text": " cloud"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 20495,
|
||||
"logprob": -0.32299805,
|
||||
"special": false,
|
||||
"text": " sky"
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
@ -66,7 +72,8 @@
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
|
||||
"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
|
||||
}
|
||||
|
@ -11,13 +11,12 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"arguments": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
"location": "Brooklyn"
|
||||
},
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
@ -27,14 +26,14 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795556,
|
||||
"created": 1712782670,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
"completion_tokens": 37,
|
||||
"prompt_tokens": 524,
|
||||
"total_tokens": 561
|
||||
}
|
||||
}
|
||||
|
@ -11,13 +11,12 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"arguments": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
"location": "Brooklyn"
|
||||
},
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
@ -27,14 +26,14 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795557,
|
||||
"created": 1712787937,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
"completion_tokens": 37,
|
||||
"prompt_tokens": 524,
|
||||
"total_tokens": 561
|
||||
}
|
||||
}
|
||||
|
@ -11,12 +11,12 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"arguments": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY"
|
||||
}
|
||||
},
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
@ -26,14 +26,14 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795557,
|
||||
"created": 1712852394,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"usage": {
|
||||
"completion_tokens": 21,
|
||||
"prompt_tokens": 187,
|
||||
"total_tokens": 208
|
||||
"completion_tokens": 48,
|
||||
"prompt_tokens": 320,
|
||||
"total_tokens": 368
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,38 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": {
|
||||
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
|
||||
},
|
||||
"description": null,
|
||||
"name": "notify_error"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1712852597,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"usage": {
|
||||
"completion_tokens": 39,
|
||||
"prompt_tokens": 496,
|
||||
"total_tokens": 535
|
||||
}
|
||||
}
|
@ -19,9 +19,9 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1710795499,
|
||||
"created": 1712788218,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native"
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
}
|
||||
|
42
integration-tests/models/test_chat_llama.py
Normal file
42
integration-tests/models/test_chat_llama.py
Normal file
@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_chat_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_chat(flash_llama_chat_handle):
|
||||
await flash_llama_chat_handle.health(300)
|
||||
return flash_llama_chat_handle.client
|
||||
|
||||
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
|
||||
response = await flash_llama_chat.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
|
||||
)
|
||||
assert response == response_snapshot
|
109
integration-tests/models/test_completion_prompts.py
Normal file
109
integration-tests/models/test_completion_prompts.py
Normal file
@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
import requests
|
||||
import json
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from text_generation.types import (
|
||||
Completion,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_completion_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_completion(flash_llama_completion_handle):
|
||||
await flash_llama_completion_handle.health(300)
|
||||
return flash_llama_completion_handle.client
|
||||
|
||||
|
||||
# NOTE: since `v1/completions` is a deprecated inferface/endpoint we do not provide a convience
|
||||
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
|
||||
|
||||
|
||||
def test_flash_llama_completion_single_prompt(
|
||||
flash_llama_completion, response_snapshot
|
||||
):
|
||||
response = requests.post(
|
||||
f"{flash_llama_completion.base_url}/v1/completions",
|
||||
json={
|
||||
"model": "tgi",
|
||||
"prompt": "Say this is a test",
|
||||
"max_tokens": 5,
|
||||
"seed": 0,
|
||||
},
|
||||
headers=flash_llama_completion.headers,
|
||||
stream=False,
|
||||
)
|
||||
response = response.json()
|
||||
assert len(response["choices"]) == 1
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
||||
response = requests.post(
|
||||
f"{flash_llama_completion.base_url}/v1/completions",
|
||||
json={
|
||||
"model": "tgi",
|
||||
"prompt": ["Say", "this", "is", "a"],
|
||||
"max_tokens": 10,
|
||||
"seed": 0,
|
||||
},
|
||||
headers=flash_llama_completion.headers,
|
||||
stream=False,
|
||||
)
|
||||
response = response.json()
|
||||
assert len(response["choices"]) == 4
|
||||
|
||||
all_indexes = [choice["index"] for choice in response["choices"]]
|
||||
all_indexes.sort()
|
||||
assert all_indexes == [0, 1, 2, 3]
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
async def test_flash_llama_completion_many_prompts_stream(
|
||||
flash_llama_completion, response_snapshot
|
||||
):
|
||||
request = {
|
||||
"model": "tgi",
|
||||
"prompt": [
|
||||
"What color is the sky?",
|
||||
"Is water wet?",
|
||||
"What is the capital of France?",
|
||||
"def mai",
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"seed": 0,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
url = f"{flash_llama_completion.base_url}/v1/completions"
|
||||
|
||||
chunks = []
|
||||
async with ClientSession(headers=flash_llama_completion.headers) as session:
|
||||
async with session.post(url, json=request) as response:
|
||||
# iterate over the stream
|
||||
async for chunk in response.content.iter_any():
|
||||
# remove "data:"
|
||||
chunk = chunk.decode().split("\n\n")
|
||||
# remove "data:" if present
|
||||
chunk = [c.replace("data:", "") for c in chunk]
|
||||
# remove empty strings
|
||||
chunk = [c for c in chunk if c]
|
||||
# parse json
|
||||
chunk = [json.loads(c) for c in chunk]
|
||||
|
||||
for c in chunk:
|
||||
chunks.append(Completion(**c))
|
||||
assert "choices" in c
|
||||
assert 0 <= c["choices"][0]["index"] <= 4
|
||||
|
||||
assert response.status == 200
|
||||
assert chunks == response_snapshot
|
@ -33,6 +33,9 @@ async def test_idefics(idefics, response_snapshot):
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert (
|
||||
response.generated_text == " \nAssistant: A rooster stands"
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@ -48,6 +51,9 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
|
||||
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert (
|
||||
generated_texts[0] == " \nAssistant: A rooster stands"
|
||||
), f"{response.generated_text}"
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
|
84
integration-tests/models/test_llava_next.py
Normal file
84
integration-tests/models/test_llava_next.py
Normal file
@ -0,0 +1,84 @@
|
||||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llava_next_handle(launcher):
|
||||
with launcher(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
num_shard=4,
|
||||
max_input_length=4000,
|
||||
max_total_tokens=4096,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llava_next(flash_llava_next_handle):
|
||||
await flash_llava_next_handle.health(300)
|
||||
return flash_llava_next_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
response = await flash_llava_next.generate(
|
||||
f"User:Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
assert (
|
||||
response.generated_text == "\n\nOnce upon a time, there was a"
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
||||
response = await flash_llava_next.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 6
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_load(
|
||||
flash_llava_next, generate_load, response_snapshot
|
||||
):
|
||||
chicken = get_chicken()
|
||||
responses = await generate_load(
|
||||
flash_llava_next,
|
||||
f"User:Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
assert generated_texts[0] == "\n\nOnce upon a time, there was a"
|
||||
assert len(generated_texts) == 4
|
||||
assert all([r.generated_text == generated_texts[0] for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
@ -45,7 +45,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 9
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def t5_sharded_handle(launcher):
|
||||
with launcher("google/flan-t5-xxl", num_shard=2) as handle:
|
||||
with launcher("google/flan-t5-xxl", num_shard=4) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
|
@ -71,34 +71,7 @@ tools = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_no_tools(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||
@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_auto(
|
||||
@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto(
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_choice(
|
||||
@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {"format": "celsius", "location": "New York, NY"},
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_stream(
|
||||
@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream(
|
||||
async for response in responses:
|
||||
count += 1
|
||||
|
||||
assert count == 20
|
||||
assert count == 38
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_insufficient_information(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
responses = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=8,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me a story about 3 sea creatures",
|
||||
},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert responses.choices[0].message.content == None
|
||||
assert responses.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"arguments": {
|
||||
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
|
||||
},
|
||||
"description": None,
|
||||
"name": "notify_error",
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "text-generation-integration-tests"
|
||||
version = "1.4.5"
|
||||
version = "2.0.1"
|
||||
description = "Text Generation Inference integration tests"
|
||||
authors = ["Nicolas Patry <nicolas@huggingface.co>"]
|
||||
|
||||
|
@ -9,8 +9,10 @@ homepage.workspace = true
|
||||
[dependencies]
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
ctrlc = { version = "3.4.1", features = ["termination"] }
|
||||
hf-hub = "0.3.2"
|
||||
nix = { version = "0.28.0", features = ["signal"] }
|
||||
serde = { version = "1.0.188", features = ["derive"] }
|
||||
once_cell = "1.19.0"
|
||||
serde = { version = "1.0.188", features = ["derive"] }
|
||||
serde_json = "1.0.107"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
|
@ -1,4 +1,5 @@
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use serde::Deserialize;
|
||||
@ -19,17 +20,23 @@ use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod env_runtime;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
max_position_embeddings: Option<usize>,
|
||||
max_seq_len: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Quantization {
|
||||
/// 4 bit quantization. Requires a specific AWQ quantized model:
|
||||
/// https://hf.co/models?search=awq.
|
||||
/// <https://hf.co/models?search=awq>.
|
||||
/// Should replace GPTQ models wherever possible because of the better latency
|
||||
Awq,
|
||||
/// 8 bit quantization, doesn't require specific model.
|
||||
/// Should be a drop-in replacement to bitsandbytes with much better performance.
|
||||
/// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
|
||||
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||
Eetq,
|
||||
/// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
|
||||
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
|
||||
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
|
||||
/// triton kernel (wider support) when it's not.
|
||||
/// AWQ has faster kernels.
|
||||
@ -47,6 +54,11 @@ enum Quantization {
|
||||
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
||||
/// perplexity performance for you model
|
||||
BitsandbytesFP4,
|
||||
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
||||
/// This dtype has native ops should be the fastest if available.
|
||||
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
||||
/// multiplication limitations.
|
||||
Fp8,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Quantization {
|
||||
@ -73,6 +85,9 @@ impl std::fmt::Display for Quantization {
|
||||
Quantization::Eetq => {
|
||||
write!(f, "eetq")
|
||||
}
|
||||
Quantization::Fp8 => {
|
||||
write!(f, "fp8")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -206,8 +221,13 @@ struct Args {
|
||||
/// for users. The larger this value, the longer prompt users can send which
|
||||
/// can impact the overall memory required to handle the load.
|
||||
/// Please note that some models have a finite range of sequence they can handle.
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_length: usize,
|
||||
/// Default to min(max_position_embeddings - 1, 4095)
|
||||
#[clap(long, env)]
|
||||
max_input_tokens: Option<usize>,
|
||||
|
||||
/// Legacy version of [`Args::max_input_tokens`].
|
||||
#[clap(long, env)]
|
||||
max_input_length: Option<usize>,
|
||||
|
||||
/// This is the most important value to set as it defines the "memory budget"
|
||||
/// of running clients requests.
|
||||
@ -217,8 +237,9 @@ struct Args {
|
||||
/// `1511` max_new_tokens.
|
||||
/// The larger this value, the larger amount each request will be in your RAM
|
||||
/// and the less effective batching can be.
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
/// Default to min(max_position_embeddings, 4096)
|
||||
#[clap(long, env)]
|
||||
max_total_tokens: Option<usize>,
|
||||
|
||||
/// This represents the ratio of waiting queries vs running queries where
|
||||
/// you want to start considering pausing the running queries to include the waiting
|
||||
@ -236,8 +257,9 @@ struct Args {
|
||||
/// Limits the number of tokens for the prefill operation.
|
||||
/// Since this operation take the most memory and is compute bound, it is interesting
|
||||
/// to limit the number of requests that can be sent.
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
/// Default to `max_input_tokens + 50` to give a bit of room.
|
||||
#[clap(long, env)]
|
||||
max_batch_prefill_tokens: Option<u32>,
|
||||
|
||||
/// **IMPORTANT** This is one critical control to allow maximum usage
|
||||
/// of the available hardware.
|
||||
@ -286,13 +308,9 @@ struct Args {
|
||||
|
||||
/// Specify the batch sizes to compute cuda graphs for.
|
||||
/// Use "0" to disable.
|
||||
#[clap(
|
||||
long,
|
||||
env,
|
||||
value_delimiter = ',',
|
||||
default_value = "1,2,4,8,16,32,64,96,128"
|
||||
)]
|
||||
cuda_graphs: Vec<usize>,
|
||||
/// Default = "1,2,4,8,16,32"
|
||||
#[clap(long, env, value_delimiter = ',')]
|
||||
cuda_graphs: Option<Vec<usize>>,
|
||||
|
||||
/// The IP address to listen on
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
@ -396,6 +414,10 @@ struct Args {
|
||||
/// Display a lot of information about your runtime environment
|
||||
#[clap(long, short, action)]
|
||||
env: bool,
|
||||
|
||||
/// Control the maximum number of inputs that a client can send in a single request
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -499,6 +521,9 @@ fn shard_manager(
|
||||
// Copy current process env
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Remove LOG_LEVEL if present
|
||||
envs.retain(|(name, _)| name != "LOG_LEVEL");
|
||||
|
||||
// Torch Distributed Env vars
|
||||
envs.push(("RANK".into(), rank.to_string().into()));
|
||||
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
|
||||
@ -586,6 +611,7 @@ fn shard_manager(
|
||||
tracing::info!("Starting shard");
|
||||
let mut p = match Command::new("text-generation-server")
|
||||
.args(shard_args)
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
@ -796,6 +822,14 @@ enum LauncherError {
|
||||
WebserverCannotStart,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for LauncherError {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for LauncherError {}
|
||||
|
||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
// Enter download tracing span
|
||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||
@ -824,6 +858,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
// Copy current process env
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Remove LOG_LEVEL if present
|
||||
envs.retain(|(name, _)| name != "LOG_LEVEL");
|
||||
|
||||
// Disable progress bar
|
||||
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
|
||||
|
||||
@ -858,6 +895,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
tracing::info!("Starting download process.");
|
||||
let mut download_process = match Command::new("text-generation-server")
|
||||
.args(download_args)
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
@ -928,6 +966,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
fn spawn_shards(
|
||||
num_shard: usize,
|
||||
args: &Args,
|
||||
cuda_graphs: Vec<usize>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
shutdown_sender: mpsc::Sender<()>,
|
||||
@ -955,11 +994,7 @@ fn spawn_shards(
|
||||
let disable_custom_kernels = args.disable_custom_kernels;
|
||||
let watermark_gamma = args.watermark_gamma;
|
||||
let watermark_delta = args.watermark_delta;
|
||||
let cuda_graphs: Vec<usize> = args
|
||||
.cuda_graphs
|
||||
.iter()
|
||||
.filter_map(|&c| if c > 0 { Some(c) } else { None })
|
||||
.collect();
|
||||
let cuda_graphs_clone = cuda_graphs.clone();
|
||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
@ -981,7 +1016,7 @@ fn spawn_shards(
|
||||
disable_custom_kernels,
|
||||
watermark_gamma,
|
||||
watermark_delta,
|
||||
cuda_graphs,
|
||||
cuda_graphs_clone,
|
||||
cuda_memory_fraction,
|
||||
rope_scaling,
|
||||
rope_factor,
|
||||
@ -1037,6 +1072,9 @@ fn compute_type(num_shard: usize) -> Option<String> {
|
||||
fn spawn_webserver(
|
||||
num_shard: usize,
|
||||
args: Args,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_prefill_tokens: u32,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
) -> Result<Child, LauncherError> {
|
||||
@ -1044,6 +1082,8 @@ fn spawn_webserver(
|
||||
// Start webserver
|
||||
tracing::info!("Starting Webserver");
|
||||
let mut router_args = vec![
|
||||
"--max-client-batch-size".to_string(),
|
||||
args.max_client_batch_size.to_string(),
|
||||
"--max-concurrent-requests".to_string(),
|
||||
args.max_concurrent_requests.to_string(),
|
||||
"--max-best-of".to_string(),
|
||||
@ -1052,12 +1092,12 @@ fn spawn_webserver(
|
||||
args.max_stop_sequences.to_string(),
|
||||
"--max-top-n-tokens".to_string(),
|
||||
args.max_top_n_tokens.to_string(),
|
||||
"--max-input-length".to_string(),
|
||||
args.max_input_length.to_string(),
|
||||
"--max-input-tokens".to_string(),
|
||||
max_input_tokens.to_string(),
|
||||
"--max-total-tokens".to_string(),
|
||||
args.max_total_tokens.to_string(),
|
||||
max_total_tokens.to_string(),
|
||||
"--max-batch-prefill-tokens".to_string(),
|
||||
args.max_batch_prefill_tokens.to_string(),
|
||||
max_batch_prefill_tokens.to_string(),
|
||||
"--waiting-served-ratio".to_string(),
|
||||
args.waiting_served_ratio.to_string(),
|
||||
"--max-waiting-tokens".to_string(),
|
||||
@ -1209,16 +1249,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
||||
}
|
||||
|
||||
fn main() -> Result<(), LauncherError> {
|
||||
match Command::new("ldconfig").spawn() {
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"Unable to refresh ldconfig cache. Skipping (useless in most cases). Details {:?}",
|
||||
err
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern match configuration
|
||||
let args: Args = Args::parse();
|
||||
|
||||
@ -1245,19 +1275,129 @@ fn main() -> Result<(), LauncherError> {
|
||||
|
||||
tracing::info!("{:?}", args);
|
||||
|
||||
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||
let model_id = args.model_id.clone();
|
||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
let api = Api::new()?;
|
||||
let repo = if let Some(ref revision) = args.revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.model(model_id)
|
||||
};
|
||||
repo.get("config.json")?
|
||||
} else {
|
||||
path.push("config.json");
|
||||
path
|
||||
};
|
||||
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: Config = serde_json::from_str(&content)?;
|
||||
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
let max_default = 4096;
|
||||
|
||||
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
|
||||
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
max_default
|
||||
} else {
|
||||
max_position_embeddings
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(LauncherError::ArgumentValidation(
|
||||
"no max defined".to_string(),
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(max_position_embeddings)
|
||||
};
|
||||
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
|
||||
|
||||
let max_input_tokens = {
|
||||
match (args.max_input_tokens, args.max_input_length) {
|
||||
(Some(max_input_tokens), Some(max_input_length)) => {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
|
||||
)));
|
||||
}
|
||||
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
|
||||
(None, None) => {
|
||||
let value = max_position_embeddings - 1;
|
||||
tracing::info!("Default `max_input_tokens` to {value}");
|
||||
value
|
||||
}
|
||||
}
|
||||
};
|
||||
let max_total_tokens = {
|
||||
match args.max_total_tokens {
|
||||
Some(max_total_tokens) => max_total_tokens,
|
||||
None => {
|
||||
let value = max_position_embeddings;
|
||||
tracing::info!("Default `max_total_tokens` to {value}");
|
||||
value
|
||||
}
|
||||
}
|
||||
};
|
||||
let max_batch_prefill_tokens = {
|
||||
match args.max_batch_prefill_tokens {
|
||||
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
|
||||
None => {
|
||||
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
|
||||
max_batch_size * max_input_tokens
|
||||
} else {
|
||||
// Adding some edge in order to account for potential block_size alignement
|
||||
// issue.
|
||||
max_input_tokens + 50
|
||||
} as u32;
|
||||
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
|
||||
value
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Validate args
|
||||
if args.max_input_length >= args.max_total_tokens {
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if args.max_input_length as u32 > args.max_batch_prefill_tokens {
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, args.max_input_length
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_input_tokens
|
||||
)));
|
||||
}
|
||||
|
||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
|
||||
#[allow(deprecated)]
|
||||
(
|
||||
None,
|
||||
Some(
|
||||
Quantization::Bitsandbytes
|
||||
| Quantization::BitsandbytesNF4
|
||||
| Quantization::BitsandbytesFP4,
|
||||
),
|
||||
) => {
|
||||
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
|
||||
vec![]
|
||||
}
|
||||
_ => {
|
||||
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
|
||||
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
|
||||
cuda_graphs
|
||||
}
|
||||
};
|
||||
|
||||
if args.validation_workers == 0 {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
@ -1276,16 +1416,16 @@ fn main() -> Result<(), LauncherError> {
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, max_batch_total_tokens
|
||||
max_batch_prefill_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if args.max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_total_tokens, max_batch_total_tokens
|
||||
max_total_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -1332,6 +1472,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
spawn_shards(
|
||||
num_shard,
|
||||
&args,
|
||||
cuda_graphs,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
shutdown_sender,
|
||||
@ -1346,11 +1487,19 @@ fn main() -> Result<(), LauncherError> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
|
||||
.map_err(|err| {
|
||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||
err
|
||||
})?;
|
||||
let mut webserver = spawn_webserver(
|
||||
num_shard,
|
||||
args,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_prefill_tokens,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
)
|
||||
.map_err(|err| {
|
||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||
err
|
||||
})?;
|
||||
|
||||
// Default exit code
|
||||
let mut exit_code = Ok(());
|
||||
|
@ -21,7 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { version = "0.3.0", features = ["tokio"] }
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = "0.21.1"
|
||||
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
||||
@ -33,7 +33,7 @@ reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { version = "0.15.1", features = ["http"] }
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.4.4", features = ["cors"] }
|
||||
@ -44,10 +44,12 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = "0.22.0"
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||
|
@ -112,10 +112,15 @@ impl Client {
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(";
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
inputs: "_test ".to_string().repeat(max_input_length as usize),
|
||||
inputs,
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
|
158
router/src/config.rs
Normal file
158
router/src/config.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct LlavaNext {
|
||||
text_config: TextConfig,
|
||||
vision_config: VisionConfig,
|
||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
fn get_anyres_image_grid_shape(
|
||||
height: usize,
|
||||
width: usize,
|
||||
grid_pinpoints: &[(usize, usize)],
|
||||
patch_size: usize,
|
||||
) -> (usize, usize) {
|
||||
let (height, width) = select_best_resolution(height, width, grid_pinpoints);
|
||||
(height / patch_size, width / patch_size)
|
||||
}
|
||||
|
||||
/// Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
/// This is done by calculating the effective and wasted resolution for each possible resolution.
|
||||
/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
|
||||
fn select_best_resolution(
|
||||
original_height: usize,
|
||||
original_width: usize,
|
||||
possible_resolutions: &[(usize, usize)],
|
||||
) -> (usize, usize) {
|
||||
let mut best_fit = None;
|
||||
let mut max_effective_resolution = 0;
|
||||
let mut min_wasted_resolution = f32::NEG_INFINITY;
|
||||
|
||||
for (height, width) in possible_resolutions {
|
||||
let wscale = *width as f32 / original_width as f32;
|
||||
let hscale = *height as f32 / original_height as f32;
|
||||
// f32 partial ord.
|
||||
let scale = if wscale > hscale { hscale } else { wscale };
|
||||
let downscaled_width = (*width as f32 * scale) as usize;
|
||||
let downscaled_height = (*height as f32 * scale) as usize;
|
||||
let effective_resolution = std::cmp::min(
|
||||
downscaled_width * downscaled_height,
|
||||
original_width * original_height,
|
||||
);
|
||||
let wasted_resolution = (width * height) - effective_resolution;
|
||||
|
||||
if effective_resolution > max_effective_resolution
|
||||
|| (effective_resolution == max_effective_resolution
|
||||
&& (wasted_resolution as f32) < min_wasted_resolution)
|
||||
{
|
||||
max_effective_resolution = effective_resolution;
|
||||
min_wasted_resolution = wasted_resolution as f32;
|
||||
best_fit = Some((*height, *width));
|
||||
}
|
||||
}
|
||||
|
||||
best_fit.unwrap_or((original_height, original_width))
|
||||
}
|
||||
|
||||
impl LlavaNext {
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let image_size = self.vision_config.image_size;
|
||||
let patch_size = self.vision_config.patch_size;
|
||||
assert!(image_size % patch_size == 0);
|
||||
let npatches = image_size / patch_size;
|
||||
let (num_patch_height, num_patch_width) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
// Ceil
|
||||
let height_of_patch = (height * npatches + width - 1) / width;
|
||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
||||
// They are only added after width
|
||||
let newline_features = height_of_patch * num_patch_width;
|
||||
// The base patch covers the entire image
|
||||
let base_features = npatches.pow(2);
|
||||
unpadded_features + newline_features + base_features
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ClipVisionModel {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Config {
|
||||
LlavaNext(LlavaNext),
|
||||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
Santacoder,
|
||||
Bloom,
|
||||
Mpt,
|
||||
GptNeox,
|
||||
Phi,
|
||||
#[serde(rename = "phi-msft")]
|
||||
PhiMsft,
|
||||
Llama,
|
||||
Baichuan,
|
||||
Gemma,
|
||||
Cohere,
|
||||
Drbx,
|
||||
Falcon,
|
||||
Mixtral,
|
||||
Starcoder2,
|
||||
Qwen2,
|
||||
Opt,
|
||||
T5,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct TextConfig {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct VisionConfig {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_llava_next_features() {
|
||||
let config = LlavaNext {
|
||||
text_config: TextConfig {},
|
||||
vision_config: VisionConfig {
|
||||
image_size: 336,
|
||||
patch_size: 14,
|
||||
},
|
||||
image_grid_pinpoints: vec![
|
||||
(336, 672),
|
||||
(672, 336),
|
||||
(672, 672),
|
||||
(1008, 336),
|
||||
(336, 1008),
|
||||
],
|
||||
};
|
||||
|
||||
let slots = config.get_number_of_features(640, 640);
|
||||
assert_eq!(slots, 2928);
|
||||
let slots = config.get_number_of_features(480, 640);
|
||||
assert_eq!(slots, 2340);
|
||||
let slots = config.get_number_of_features(899, 1024);
|
||||
assert_eq!(slots, 2732);
|
||||
let slots = config.get_number_of_features(1024, 899);
|
||||
assert_eq!(slots, 3320);
|
||||
}
|
||||
}
|
@ -1,12 +1,15 @@
|
||||
/// Batching and inference logic
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
|
||||
Message, PrefillToken, Queue, Token,
|
||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
||||
};
|
||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use nohash_hasher::IntMap;
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@ -86,7 +89,18 @@ impl Infer {
|
||||
|
||||
let chat_template = tokenizer_config
|
||||
.chat_template
|
||||
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||
.and_then(|t| match t {
|
||||
ChatTemplateVersions::Single(template) => Some(template),
|
||||
ChatTemplateVersions::Multiple(templates) => templates
|
||||
.into_iter()
|
||||
.find(|t| t.name == "default")
|
||||
.map(|t| t.template),
|
||||
})
|
||||
.map(|t| {
|
||||
// .strip() is not supported in minijinja
|
||||
let t = t.replace(".strip()", " | trim");
|
||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||
});
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
@ -174,11 +188,15 @@ impl Infer {
|
||||
|
||||
/// Apply the chat template to the chat request
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||
pub(crate) fn apply_chat_template(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||
.apply(messages)
|
||||
.apply(messages, grammar_with_prompt)
|
||||
.map_err(|e| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||
tracing::error!("{e}");
|
||||
@ -311,6 +329,7 @@ struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
@ -318,6 +337,10 @@ impl ChatTemplate {
|
||||
let mut env = Box::new(Environment::new());
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
@ -327,21 +350,159 @@ impl ChatTemplate {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||
fn apply(
|
||||
&self,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content = Some(format!(
|
||||
"{}\n---\n{}\n{}",
|
||||
last_message.content.as_deref().unwrap_or_default(),
|
||||
tool_prompt,
|
||||
tools
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: Option<ToolType>,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(Some(tools));
|
||||
}
|
||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
@ -757,6 +918,8 @@ pub enum InferError {
|
||||
IncompleteGeneration,
|
||||
#[error("Template error: {0}")]
|
||||
TemplateError(#[from] minijinja::Error),
|
||||
#[error("Tool error: {0}")]
|
||||
ToolError(String),
|
||||
}
|
||||
|
||||
impl InferError {
|
||||
@ -767,6 +930,7 @@ impl InferError {
|
||||
InferError::ValidationError(_) => "validation",
|
||||
InferError::IncompleteGeneration => "incomplete_generation",
|
||||
InferError::TemplateError(_) => "template_error",
|
||||
InferError::ToolError(_) => "tool_error",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -838,6 +1002,7 @@ mod tests {
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
@ -913,6 +1078,7 @@ mod tests {
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||
@ -987,6 +1153,7 @@ mod tests {
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
@ -1045,6 +1212,7 @@ mod tests {
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
@ -1099,93 +1267,100 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "_base",
|
||||
chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some(""),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "blenderbot",
|
||||
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "blenderbot_small",
|
||||
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "bloom",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "gpt_neox",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "gpt2",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "llama",
|
||||
// NOTE: the `.strip()` has been replaced with `| trim` in the following template
|
||||
chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "whisper",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
|
||||
}
|
||||
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
@ -1211,7 +1386,8 @@ mod tests {
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>")
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
@ -1236,8 +1412,9 @@ mod tests {
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>"
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
|
||||
@ -1247,6 +1424,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<bos>"),
|
||||
eos_token: Some("<eos>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
@ -1258,8 +1436,9 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
@ -1269,6 +1448,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
@ -1276,10 +1456,11 @@ mod tests {
|
||||
name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
|
||||
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
@ -1292,6 +1473,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>",
|
||||
},
|
||||
@ -1303,6 +1485,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
@ -1315,6 +1498,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ",
|
||||
},
|
||||
@ -1326,6 +1510,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!",
|
||||
},
|
||||
@ -1337,6 +1522,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!",
|
||||
},
|
||||
@ -1348,6 +1534,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<|begin▁of▁sentence|>"),
|
||||
eos_token: Some("<|end▁of▁sentence|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n",
|
||||
},
|
||||
@ -1359,8 +1546,9 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>"
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "internlm/internlm2-chat-7b",
|
||||
@ -1370,6 +1558,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
@ -1381,6 +1570,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<|begin▁of▁sentence|>"),
|
||||
eos_token: Some("<|EOT|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n",
|
||||
},
|
||||
@ -1393,6 +1583,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<|endoftext|>"),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!",
|
||||
},
|
||||
@ -1404,6 +1595,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
@ -1415,6 +1607,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!",
|
||||
},
|
||||
@ -1426,6 +1619,7 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<|begin▁of▁sentence|>"),
|
||||
eos_token: Some("</EOT>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
|
||||
},
|
||||
@ -1441,9 +1635,10 @@ mod tests {
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
|
@ -1,3 +1,4 @@
|
||||
pub mod config;
|
||||
mod health;
|
||||
/// Text Generation Inference Webserver
|
||||
mod infer;
|
||||
@ -48,9 +49,22 @@ pub struct HubModelInfo {
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct ChatTemplate {
|
||||
name: String,
|
||||
template: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatTemplateVersions {
|
||||
Single(String),
|
||||
Multiple(Vec<ChatTemplate>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
pub chat_template: Option<String>,
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
pub completion_template: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub bos_token: Option<String>,
|
||||
@ -65,7 +79,7 @@ impl HubTokenizerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
pub(crate) enum GrammarType {
|
||||
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||
@ -141,6 +155,8 @@ pub struct Info {
|
||||
pub max_batch_size: Option<usize>,
|
||||
#[schema(example = "2")]
|
||||
pub validation_workers: usize,
|
||||
#[schema(example = "32")]
|
||||
pub max_client_batch_size: usize,
|
||||
/// Router Info
|
||||
#[schema(example = "0.5.0")]
|
||||
pub version: &'static str,
|
||||
@ -222,7 +238,7 @@ pub(crate) struct GenerateParameters {
|
||||
#[schema(default = "true")]
|
||||
pub details: bool,
|
||||
#[serde(default)]
|
||||
#[schema(default = "true")]
|
||||
#[schema(default = "false")]
|
||||
pub decoder_input_details: bool,
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
@ -236,6 +252,7 @@ pub(crate) struct GenerateParameters {
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||
pub top_n_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub grammar: Option<GrammarType>,
|
||||
}
|
||||
|
||||
@ -266,6 +283,34 @@ fn default_parameters() -> GenerateParameters {
|
||||
}
|
||||
}
|
||||
|
||||
mod prompt_serde {
|
||||
use serde::{self, Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => Ok(vec![s]),
|
||||
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
|
||||
"Empty array detected. Do not use an empty array for the prompt.",
|
||||
)),
|
||||
Value::Array(arr) => arr
|
||||
.iter()
|
||||
.map(|v| match v {
|
||||
Value::String(s) => Ok(s.to_owned()),
|
||||
_ => Err(serde::de::Error::custom("Expected a string")),
|
||||
})
|
||||
.collect(),
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"Expected a string or an array of strings",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
pub struct CompletionRequest {
|
||||
/// UNUSED
|
||||
@ -275,7 +320,8 @@ pub struct CompletionRequest {
|
||||
|
||||
/// The prompt to generate completions for.
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub prompt: String,
|
||||
#[serde(deserialize_with = "prompt_serde::deserialize")]
|
||||
pub prompt: Vec<String>,
|
||||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
@ -655,7 +701,7 @@ pub(crate) struct ChatRequest {
|
||||
#[serde(default = "default_tool_prompt")]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
|
||||
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
|
||||
)]
|
||||
pub tool_prompt: Option<String>,
|
||||
|
||||
@ -668,7 +714,7 @@ pub(crate) struct ChatRequest {
|
||||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
|
||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||
)
|
||||
}
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
@ -713,26 +759,26 @@ mod deserialize_tool_choice {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||
pub struct Tools {
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct FunctionsMap {
|
||||
#[serde(rename = "$functions")]
|
||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct FunctionRef {
|
||||
#[serde(rename = "$ref")]
|
||||
ref_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct Properties {
|
||||
#[serde(serialize_with = "serialize_function")]
|
||||
function: Vec<FunctionRef>,
|
||||
@ -753,7 +799,8 @@ pub(crate) struct FunctionDefinition {
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
#[serde(alias = "parameters")]
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
@ -765,12 +812,14 @@ pub(crate) struct Tool {
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Clone, Serialize, Deserialize, Default)]
|
||||
pub(crate) struct ChatTemplateInputs<'a> {
|
||||
messages: Vec<Message>,
|
||||
bos_token: Option<&'a str>,
|
||||
eos_token: Option<&'a str>,
|
||||
add_generation_prompt: bool,
|
||||
tools: Option<&'a str>,
|
||||
tools_prompt: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
@ -977,7 +1026,10 @@ mod tests {
|
||||
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||
|
||||
// check that we successfully parsed the tokens
|
||||
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||
assert_eq!(
|
||||
config.chat_template,
|
||||
Some(ChatTemplateVersions::Single("test".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
@ -1009,7 +1061,10 @@ mod tests {
|
||||
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||
|
||||
// check that we successfully parsed the tokens
|
||||
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||
assert_eq!(
|
||||
config.chat_template,
|
||||
Some(ChatTemplateVersions::Single("test".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
|
@ -13,6 +13,7 @@ use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
@ -34,7 +35,7 @@ struct Args {
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_length: usize,
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
@ -77,6 +78,8 @@ struct Args {
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -89,7 +92,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
@ -111,19 +114,20 @@ async fn main() -> Result<(), RouterError> {
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
init_logging(otlp_endpoint, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_length >= max_total_tokens {
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_length as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
@ -191,15 +195,19 @@ async fn main() -> Result<(), RouterError> {
|
||||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (tokenizer, model_info) = if local_model {
|
||||
let (tokenizer, model_info, config) = if local_model {
|
||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||
let model_info = HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
};
|
||||
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| serde_json::from_str(c).ok());
|
||||
|
||||
(tokenizer, model_info)
|
||||
(tokenizer, model_info, config)
|
||||
} else if let Some(api) = api.clone() {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
@ -212,6 +220,19 @@ async fn main() -> Result<(), RouterError> {
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
|
||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
|
||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
HubModelInfo {
|
||||
@ -221,7 +242,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
}
|
||||
});
|
||||
|
||||
(tokenizer, model_info)
|
||||
(tokenizer, model_info, config)
|
||||
} else {
|
||||
// No API and no local model
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
@ -229,6 +250,8 @@ async fn main() -> Result<(), RouterError> {
|
||||
));
|
||||
};
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
|
||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
||||
tracing::info!("Using local tokenizer config from user specified path");
|
||||
@ -291,7 +314,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
tracing::info!("Warming up model");
|
||||
let max_supported_batch_total_tokens = match sharded_client
|
||||
.warmup(
|
||||
max_input_length as u32,
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
@ -354,7 +377,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
@ -363,6 +386,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
max_batch_size,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
config,
|
||||
validation_workers,
|
||||
addr,
|
||||
cors_allow_origin,
|
||||
@ -372,6 +396,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
tokenizer_config,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
@ -381,12 +406,15 @@ async fn main() -> Result<(), RouterError> {
|
||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
||||
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||
let mut layers = Vec::new();
|
||||
|
||||
// STDOUT/STDERR layer
|
||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_file(true)
|
||||
.with_ansi(ansi)
|
||||
.with_line_number(true);
|
||||
|
||||
let fmt_layer = match json_output {
|
||||
|
@ -190,16 +190,22 @@ impl State {
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
if self.entries.is_empty() {
|
||||
tracing::debug!("No queue");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if we have enough entries
|
||||
if let Some(min_size) = min_size {
|
||||
if self.entries.len() < min_size {
|
||||
tracing::debug!("Not enough entries");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
@ -218,6 +224,7 @@ impl State {
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -254,10 +261,12 @@ impl State {
|
||||
{
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
@ -288,6 +297,7 @@ impl State {
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
use crate::config::Config;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
@ -14,7 +15,8 @@ use crate::{
|
||||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
@ -22,20 +24,21 @@ use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{http, Json, Router};
|
||||
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
||||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{ShardInfo, ShardedClient};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
use tokio::signal;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
@ -161,10 +164,20 @@ async fn generate(
|
||||
Json(req): Json<GenerateRequest>,
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
generate_internal(infer, ComputeType(compute_type), Json(req), span).await
|
||||
}
|
||||
|
||||
async fn generate_internal(
|
||||
infer: Extension<Infer>,
|
||||
ComputeType(compute_type): ComputeType,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
span: tracing::Span,
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
// Do not long ultra long inputs, like image payloads.
|
||||
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
|
||||
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
let mut add_prompt = None;
|
||||
@ -358,12 +371,13 @@ async fn generate_stream(
|
||||
HeaderMap,
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let span = tracing::Span::current();
|
||||
let on_message_callback = |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
event.json_data(stream_token).unwrap()
|
||||
};
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
|
||||
generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
(headers, sse)
|
||||
}
|
||||
@ -373,8 +387,8 @@ async fn generate_stream_internal(
|
||||
ComputeType(compute_type): ComputeType,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||
span: tracing::Span,
|
||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
@ -547,7 +561,11 @@ async fn generate_stream_internal(
|
||||
path = "/v1/completions",
|
||||
request_body = CompletionRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = ChatCompletionChunk),
|
||||
(status = 200, description = "Generated Chat Completion",
|
||||
content(
|
||||
("application/json" = Completion),
|
||||
("text/event-stream" = CompletionCompleteChunk),
|
||||
)),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
@ -576,6 +594,7 @@ async fn completions(
|
||||
Extension(info): Extension<Info>,
|
||||
Json(req): Json<CompletionRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
let stream = req.stream;
|
||||
@ -595,100 +614,299 @@ async fn completions(
|
||||
));
|
||||
}
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: req.prompt.to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: req.temperature,
|
||||
repetition_penalty: req.repetition_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
top_k: None,
|
||||
top_p: req.top_p,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
},
|
||||
};
|
||||
if req.prompt.len() > info.max_client_batch_size {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: format!(
|
||||
"Number of prompts exceeds the maximum allowed batch size of {}",
|
||||
info.max_client_batch_size
|
||||
),
|
||||
error_type: "batch size exceeded".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let generate_requests: Vec<GenerateRequest> = req
|
||||
.prompt
|
||||
.iter()
|
||||
.map(|prompt| GenerateRequest {
|
||||
inputs: prompt.to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: req.temperature,
|
||||
repetition_penalty: req.repetition_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
top_k: None,
|
||||
top_p: req.top_p,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut x_compute_type = None;
|
||||
let mut x_compute_characters = 0u32;
|
||||
let mut x_accel_buffering = None;
|
||||
|
||||
if stream {
|
||||
let on_message_callback = move |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
let mut response_streams = FuturesOrdered::new();
|
||||
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
||||
let model_id = info.model_id.clone();
|
||||
let system_fingerprint =
|
||||
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
// Create a future for each generate_stream_internal call.
|
||||
let generate_future = async move {
|
||||
let on_message_callback = move |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
|
||||
event
|
||||
.json_data(CompletionCompleteChunk {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created: current_time,
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: "".to_string(),
|
||||
index: 0,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
event
|
||||
.json_data(CompletionCompleteChunk {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created: current_time,
|
||||
|
||||
model: info.model_id.clone(),
|
||||
system_fingerprint: format!(
|
||||
"{}-{}",
|
||||
info.version,
|
||||
info.docker_label.unwrap_or("native")
|
||||
),
|
||||
})
|
||||
.map_or_else(
|
||||
|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
},
|
||||
|data| data,
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: "".to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
})
|
||||
.map_or_else(|_e| Event::default(), |data| data)
|
||||
};
|
||||
|
||||
let (header_tx, header_rx) = oneshot::channel();
|
||||
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (header_map, sse) = generate_stream_internal(
|
||||
infer_clone.clone(),
|
||||
compute_type_clone.clone(),
|
||||
Json(generate_request),
|
||||
on_message_callback,
|
||||
span_clone.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// send and dont wait for response
|
||||
let _ = header_tx.send(header_map);
|
||||
|
||||
// pin an emit messages to the sse_tx
|
||||
let mut sse = Box::pin(sse);
|
||||
while let Some(event) = sse.next().await {
|
||||
if sse_tx.send(event).is_err() {
|
||||
tracing::error!("Failed to send event. Receiver dropped.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(header_rx, sse_rx)
|
||||
};
|
||||
response_streams.push_back(generate_future);
|
||||
}
|
||||
|
||||
let mut all_rxs = vec![];
|
||||
|
||||
while let Some((header_rx, sse_rx)) = response_streams.next().await {
|
||||
all_rxs.push(sse_rx);
|
||||
|
||||
// get the headers from the first response of each stream
|
||||
let headers = header_rx.await.map_err(|e| {
|
||||
tracing::error!("Failed to get headers: {:?}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to get headers".to_string(),
|
||||
error_type: "headers".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
if x_compute_type.is_none() {
|
||||
x_compute_type = headers
|
||||
.get("x-compute-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string());
|
||||
|
||||
x_accel_buffering = headers
|
||||
.get("x-accel-buffering")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string());
|
||||
}
|
||||
x_compute_characters += headers
|
||||
.get("x-compute-characters")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(0);
|
||||
}
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(x_compute_type) = x_compute_type {
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
}
|
||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||
if let Some(x_accel_buffering) = x_accel_buffering {
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
}
|
||||
|
||||
// now sink the sse streams into a single stream and remove the ones that are done
|
||||
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
|
||||
loop {
|
||||
let mut i = 0;
|
||||
while i < all_rxs.len() {
|
||||
let rx = &mut all_rxs[i];
|
||||
select! {
|
||||
Some(event) = rx.recv() => {
|
||||
yield event;
|
||||
}
|
||||
else => {
|
||||
all_rxs.remove(i);
|
||||
continue; // skip the increment to handle the next element at the same index
|
||||
}
|
||||
}
|
||||
i += 1; // only increment when no element was removed
|
||||
}
|
||||
|
||||
if all_rxs.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let (headers, response_stream) = generate_stream_internal(
|
||||
infer,
|
||||
compute_type,
|
||||
Json(generate_request),
|
||||
on_message_callback,
|
||||
)
|
||||
.await;
|
||||
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
let (headers, Json(generation)) = generate(
|
||||
Extension(infer),
|
||||
Extension(compute_type),
|
||||
Json(generate_request),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let details = generation.details.ok_or((
|
||||
// this should never happen but handle if details are missing unexpectedly
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "No details in generation".to_string(),
|
||||
error_type: "no details".to_string(),
|
||||
}),
|
||||
))?;
|
||||
let responses = FuturesUnordered::new();
|
||||
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
let response_future = async move {
|
||||
let result = generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await;
|
||||
result.map(|(headers, generation)| (index, headers, generation))
|
||||
};
|
||||
responses.push(response_future);
|
||||
}
|
||||
let generate_responses = responses.try_collect::<Vec<_>>().await?;
|
||||
|
||||
let mut prompt_tokens = 0u32;
|
||||
let mut completion_tokens = 0u32;
|
||||
let mut total_tokens = 0u32;
|
||||
|
||||
let mut x_compute_time = 0u32;
|
||||
let mut x_total_time = 0u32;
|
||||
let mut x_validation_time = 0u32;
|
||||
let mut x_queue_time = 0u32;
|
||||
let mut x_inference_time = 0u32;
|
||||
let mut x_time_per_token = 0u32;
|
||||
let mut x_prompt_tokens = 0u32;
|
||||
let mut x_generated_tokens = 0u32;
|
||||
|
||||
let choices = generate_responses
|
||||
.into_iter()
|
||||
.map(|(index, headers, Json(generation))| {
|
||||
let details = generation.details.ok_or((
|
||||
// this should never happen but handle if details are missing unexpectedly
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "No details in generation".to_string(),
|
||||
error_type: "no details".to_string(),
|
||||
}),
|
||||
))?;
|
||||
|
||||
if x_compute_type.is_none() {
|
||||
x_compute_type = headers
|
||||
.get("x-compute-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string());
|
||||
}
|
||||
|
||||
// accumulate headers and usage from each response
|
||||
x_compute_time += headers
|
||||
.get("x-compute-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_compute_characters += headers
|
||||
.get("x-compute-characters")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_total_time += headers
|
||||
.get("x-total-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_validation_time += headers
|
||||
.get("x-validation-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_queue_time += headers
|
||||
.get("x-queue-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_inference_time += headers
|
||||
.get("x-inference-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_time_per_token += headers
|
||||
.get("x-time-per-token")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_prompt_tokens += headers
|
||||
.get("x-prompt-tokens")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_generated_tokens += headers
|
||||
.get("x-generated-tokens")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
prompt_tokens += details.prefill.len() as u32;
|
||||
completion_tokens += details.generated_tokens;
|
||||
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||
|
||||
Ok(CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: generation.generated_text,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|(status, Json(err))| (status, Json(err)))?;
|
||||
|
||||
let response = Completion {
|
||||
id: "".to_string(),
|
||||
@ -700,19 +918,30 @@ async fn completions(
|
||||
info.version,
|
||||
info.docker_label.unwrap_or("native")
|
||||
),
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: 0,
|
||||
logprobs: None,
|
||||
text: generation.generated_text,
|
||||
}],
|
||||
choices,
|
||||
usage: Usage {
|
||||
prompt_tokens: details.prefill.len() as u32,
|
||||
completion_tokens: details.generated_tokens,
|
||||
total_tokens: details.prefill.len() as u32 + details.generated_tokens,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
};
|
||||
|
||||
// headers similar to `generate` but aggregated
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(x_compute_type) = x_compute_type {
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
}
|
||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||
headers.insert("x-total-time", x_total_time.into());
|
||||
headers.insert("x-validation-time", x_validation_time.into());
|
||||
headers.insert("x-queue-time", x_queue_time.into());
|
||||
headers.insert("x-inference-time", x_inference_time.into());
|
||||
headers.insert("x-time-per-token", x_time_per_token.into());
|
||||
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
||||
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
||||
if let Some(x_accel_buffering) = x_accel_buffering {
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
}
|
||||
Ok((headers, Json(response)).into_response())
|
||||
}
|
||||
}
|
||||
@ -724,7 +953,11 @@ async fn completions(
|
||||
path = "/v1/chat/completions",
|
||||
request_body = ChatRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = ChatCompletionChunk),
|
||||
(status = 200, description = "Generated Chat Completion",
|
||||
content(
|
||||
("application/json" = ChatCompletion),
|
||||
("text/event-stream" = ChatCompletionChunk),
|
||||
)),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
@ -753,21 +986,32 @@ async fn chat_completions(
|
||||
Extension(info): Extension<Info>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
let stream = req.stream;
|
||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||
let repetition_penalty = req
|
||||
.presence_penalty
|
||||
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||
.map(|x| x + 2.0);
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
let stop = req.stop.unwrap_or_default();
|
||||
let ChatRequest {
|
||||
logprobs,
|
||||
max_tokens,
|
||||
messages,
|
||||
presence_penalty,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
..
|
||||
} = req;
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let stop = stop.unwrap_or_default();
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
@ -781,60 +1025,28 @@ async fn chat_completions(
|
||||
}
|
||||
};
|
||||
|
||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Tool choice not found in tool names".to_string(),
|
||||
error_type: "Tool not found".to_string(),
|
||||
}),
|
||||
)
|
||||
})?
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
let grammar_with_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
(func.name, func.parameters)
|
||||
})
|
||||
.collect();
|
||||
let typed_grammar = grammar_with_prompt
|
||||
.as_ref()
|
||||
.map(|(grammar, _)| grammar.clone());
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||
} else {
|
||||
None
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// build the request passing some parameters
|
||||
@ -858,7 +1070,7 @@ async fn chat_completions(
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar: tool_grammar.clone(),
|
||||
grammar: typed_grammar,
|
||||
},
|
||||
};
|
||||
|
||||
@ -912,17 +1124,14 @@ async fn chat_completions(
|
||||
compute_type,
|
||||
Json(generate_request),
|
||||
on_message_callback,
|
||||
span,
|
||||
)
|
||||
.await;
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
let (headers, Json(generation)) = generate(
|
||||
Extension(infer),
|
||||
Extension(compute_type),
|
||||
Json(generate_request),
|
||||
)
|
||||
.await?;
|
||||
let (headers, Json(generation)) =
|
||||
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
@ -941,27 +1150,28 @@ async fn chat_completions(
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: 0,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: "tools".to_string(),
|
||||
parameters: gen_text_value.get("function").map_or_else(
|
||||
|| {
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
},
|
||||
|f| Ok(f.clone()),
|
||||
)?,
|
||||
name: gen_text_value
|
||||
.get("function")
|
||||
.and_then(|f| f.get("_name"))
|
||||
.and_then(|name| name.as_str())
|
||||
.unwrap_or("default_function_name")
|
||||
.to_string(),
|
||||
// Serialize the JSON object obtained from "function" to an escaped JSON string
|
||||
arguments: gen_text_value
|
||||
.get("function")
|
||||
.map(|f| {
|
||||
let mut f_cloned = f.clone();
|
||||
if let Value::Object(ref mut props) = f_cloned {
|
||||
props.remove("_name");
|
||||
}
|
||||
f_cloned
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
@ -1018,6 +1228,7 @@ async fn vertex_compatibility(
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
// check that theres at least one instance
|
||||
@ -1049,10 +1260,11 @@ async fn vertex_compatibility(
|
||||
};
|
||||
|
||||
async {
|
||||
generate(
|
||||
generate_internal(
|
||||
Extension(infer.clone()),
|
||||
Extension(compute_type.clone()),
|
||||
compute_type.clone(),
|
||||
Json(generate_request),
|
||||
span.clone(),
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
@ -1154,6 +1366,7 @@ pub async fn run(
|
||||
max_batch_size: Option<usize>,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
@ -1163,6 +1376,7 @@ pub async fn run(
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
) -> Result<(), axum::BoxError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -1236,6 +1450,7 @@ pub async fn run(
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
@ -1336,6 +1551,7 @@ pub async fn run(
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
validation_workers,
|
||||
max_client_batch_size,
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
sha: option_env!("VERGEN_GIT_SHA"),
|
||||
docker_label: option_env!("DOCKER_LABEL"),
|
||||
@ -1535,6 +1751,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
};
|
||||
|
||||
(
|
||||
|
@ -1,15 +1,19 @@
|
||||
use crate::config::Config;
|
||||
/// Payload validation logic
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
||||
use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use text_generation_client::{
|
||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::TruncationDirection;
|
||||
// use tokenizers::TruncationDirection;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use image::{io::Reader as ImageReader, ImageFormat};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{instrument, Span};
|
||||
@ -34,6 +38,7 @@ impl Validation {
|
||||
pub(crate) fn new(
|
||||
workers: usize,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
@ -50,12 +55,13 @@ impl Validation {
|
||||
// Create workers
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
let config_clone = config.clone();
|
||||
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
|
||||
senders.push(tokenizer_sender);
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokenizer_worker(tokenizer_clone, tokenizer_receiver)
|
||||
tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
|
||||
});
|
||||
}
|
||||
|
||||
@ -155,14 +161,17 @@ impl Validation {
|
||||
} else {
|
||||
return Err(ValidationError::UnsetMaxNewTokens);
|
||||
};
|
||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||
let mut input_length = truncate.unwrap_or(self.max_input_length);
|
||||
|
||||
// We don't have a tokenizer, therefore we have no idea how long is the query, let
|
||||
// them through and hope for the best.
|
||||
// Validate MaxNewTokens
|
||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||
return Err(ValidationError::MaxNewTokens(
|
||||
self.max_total_tokens - self.max_input_length,
|
||||
max_new_tokens,
|
||||
));
|
||||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||
// return Err(ValidationError::MaxNewTokens(
|
||||
// self.max_total_tokens - self.max_input_length,
|
||||
// max_new_tokens,
|
||||
// ));
|
||||
}
|
||||
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
@ -408,48 +417,137 @@ async fn round_robin_task(
|
||||
}
|
||||
|
||||
/// Start tokenization workers
|
||||
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
|
||||
fn tokenizer_worker(
|
||||
tokenizer: Tokenizer,
|
||||
config: Option<Config>,
|
||||
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
|
||||
) {
|
||||
// Loop over requests
|
||||
let is_multimodal = {
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
vocab.contains_key("<image>")
|
||||
};
|
||||
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, &config))
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
|
||||
match mimetype {
|
||||
"image/png" => Some(ImageFormat::Png),
|
||||
"image/jpeg" => Some(ImageFormat::Jpeg),
|
||||
"image/jpg" => Some(ImageFormat::Jpeg),
|
||||
"image/gif" => Some(ImageFormat::Gif),
|
||||
"image/webp" => Some(ImageFormat::WebP),
|
||||
"image/tiff" => Some(ImageFormat::Tiff),
|
||||
// "image/pnm"=>Some(ImageFormat::Pnm),
|
||||
// "image/tga"=>Some(ImageFormat::Tga),
|
||||
// "image/dds"=>Some(ImageFormat::Dds),
|
||||
// "image/bmp"=>Some(ImageFormat::Bmp),
|
||||
// "image/ico"=>Some(ImageFormat::Ico),
|
||||
// "image/x-exr"=>Some(ImageFormat::OpenExr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
fn format_to_mimetype(format: ImageFormat) -> String {
|
||||
match format {
|
||||
ImageFormat::Png => "image/png",
|
||||
ImageFormat::Jpeg => "image/jpeg",
|
||||
ImageFormat::Gif => "image/gif",
|
||||
ImageFormat::WebP => "image/webp",
|
||||
ImageFormat::Tiff => "image/tiff",
|
||||
_ => "application/octet-stream",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||
if input.starts_with(" || input.starts_with(" {
|
||||
let url = &input["..input.len() - 1];
|
||||
let data = reqwest::blocking::get(url)?.bytes()?;
|
||||
|
||||
let format = image::guess_format(&data)?;
|
||||
// TODO Remove this clone
|
||||
let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
let mimetype = format_to_mimetype(format);
|
||||
let encoded = STANDARD.encode(data);
|
||||
let data_uri = format!("");
|
||||
Ok((data_uri, height, width))
|
||||
} else if input.starts_with(" {
|
||||
// Remove 
|
||||
let content = &input["..input.len() - 1];
|
||||
let tokens: Vec<_> = content.split(';').collect();
|
||||
if tokens.len() != 2 {
|
||||
return Err(ValidationError::InvalidImageContent(content.to_string()));
|
||||
}
|
||||
let mimetype = tokens[0];
|
||||
let content = tokens[1];
|
||||
|
||||
if !content.starts_with("base64,") {
|
||||
return Err(ValidationError::InvalidImageContent(content.to_string()));
|
||||
}
|
||||
|
||||
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
|
||||
let img = if let Some(format) = format_from_mimetype(mimetype) {
|
||||
ImageReader::with_format(Cursor::new(data), format).decode()?
|
||||
} else {
|
||||
ImageReader::new(Cursor::new(data))
|
||||
.with_guessed_format()
|
||||
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
|
||||
.decode()?
|
||||
};
|
||||
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
Ok((input.to_string(), height, width))
|
||||
} else {
|
||||
Err(ValidationError::InvalidImageContent(input.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get input length and optionally truncate it
|
||||
fn prepare_input(
|
||||
mut inputs: String,
|
||||
truncate: Option<usize>,
|
||||
_truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
is_multimodal: bool,
|
||||
config: &Option<Config>,
|
||||
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
||||
let simplified_query = if is_multimodal {
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
RE.replace_all(&inputs, "<image>").into()
|
||||
} else {
|
||||
inputs.clone()
|
||||
};
|
||||
// Get the number of tokens in the input
|
||||
let mut encoding = tokenizer
|
||||
.encode(simplified_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
// Optionally truncate
|
||||
if let Some(truncate) = truncate {
|
||||
if truncate < encoding.len() && !is_multimodal {
|
||||
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||
inputs = tokenizer
|
||||
.decode(encoding.get_ids(), false)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let tokenizer_query = match config {
|
||||
Some(Config::LlavaNext(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
}
|
||||
}
|
||||
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
|
||||
_ => inputs.clone(),
|
||||
};
|
||||
|
||||
// Get the number of tokens in the input
|
||||
let encoding = tokenizer
|
||||
.encode(tokenizer_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
Ok((encoding, inputs))
|
||||
}
|
||||
@ -523,6 +621,16 @@ pub enum ValidationError {
|
||||
Grammar,
|
||||
#[error("grammar is not valid: {0}")]
|
||||
InvalidGrammar(String),
|
||||
#[error("base64 encoding is invalid: {0}")]
|
||||
InvalidBase64(#[from] base64::DecodeError),
|
||||
#[error("invalid image: {0}")]
|
||||
InvalidImage(#[from] image::ImageError),
|
||||
#[error("invalid integer: {0}")]
|
||||
InvalidInt(#[from] core::num::TryFromIntError),
|
||||
#[error("invalid image content: {0}")]
|
||||
InvalidImageContent(String),
|
||||
#[error("Could not fetch image: {0}")]
|
||||
FailedFetchImage(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -541,9 +649,11 @@ mod tests {
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -557,8 +667,9 @@ mod tests {
|
||||
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
||||
_ => panic!("Unexpected not max new tokens"),
|
||||
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
||||
Ok((_s, 0, 10)) => (),
|
||||
r => panic!("Unexpected not max new tokens: {r:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -572,9 +683,11 @@ mod tests {
|
||||
let max_total_tokens = 6;
|
||||
let disable_grammar_support = true;
|
||||
let workers = 1;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -603,9 +716,11 @@ mod tests {
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -639,9 +754,11 @@ mod tests {
|
||||
let max_total_tokens = 106;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -704,9 +821,11 @@ mod tests {
|
||||
let max_total_tokens = 106;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
|
@ -17,9 +17,6 @@ gen-server:
|
||||
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch text_generation_server/pb/__init__.py
|
||||
|
||||
install-megablocks:
|
||||
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
|
||||
|
||||
install: gen-server
|
||||
pip install pip --upgrade
|
||||
pip install -r requirements_cuda.txt
|
||||
|
@ -1,4 +1,4 @@
|
||||
eetq_commit := 71adb5e191bb8290069a580abff0355d7b2dd5c9
|
||||
eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
|
||||
|
||||
eetq:
|
||||
# Clone eetq
|
||||
|
@ -1,4 +1,4 @@
|
||||
flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
|
||||
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
||||
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
|
||||
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
vllm-cuda:
|
||||
# Clone vllm
|
||||
pip install -U ninja packaging --no-cache-dir
|
||||
git clone https://github.com/vllm-project/vllm.git vllm
|
||||
git clone https://github.com/OlivierDehaene/vllm.git vllm
|
||||
|
||||
build-vllm-cuda: vllm-cuda
|
||||
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
|
||||
cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
|
||||
cd vllm && python setup.py build
|
||||
|
||||
install-vllm-cuda: build-vllm-cuda
|
||||
|
1358
server/poetry.lock
generated
1358
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "text-generation-server"
|
||||
version = "1.4.5"
|
||||
version = "2.0.1"
|
||||
description = "Text Generation Inference Python gRPC Server"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
||||
@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
|
||||
grpcio-reflection = "^1.51.1"
|
||||
grpc-interceptor = "^0.15.0"
|
||||
typer = "^0.6.1"
|
||||
accelerate = { version = "^0.28.0", optional = true }
|
||||
accelerate = { version = "^0.29.1", optional = true }
|
||||
bitsandbytes = { version = "^0.43.0", optional = true }
|
||||
safetensors = "^0.4"
|
||||
loguru = "^0.6.0"
|
||||
@ -24,13 +24,13 @@ opentelemetry-exporter-otlp = "^1.15.0"
|
||||
opentelemetry-instrumentation-grpc = "^0.36b0"
|
||||
hf-transfer = "^0.1.2"
|
||||
sentencepiece = "^0.1.97"
|
||||
tokenizers = "^0.15.0"
|
||||
tokenizers = "^0.19.1"
|
||||
huggingface-hub = "^0.19.3"
|
||||
transformers = "^4.38"
|
||||
transformers = "^4.40"
|
||||
einops = "^0.6.1"
|
||||
texttable = { version = "^1.6.7", optional = true }
|
||||
datasets = { version = "^2.14.0", optional = true }
|
||||
peft = { version = "^0.9", optional = true }
|
||||
peft = { version = "^0.10", optional = true }
|
||||
torch = { version = "^2.1.1", optional = true }
|
||||
scipy = "^1.11.1"
|
||||
pillow = "^10.0.0"
|
||||
|
@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -27,20 +27,20 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -27,20 +27,20 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -19,6 +19,7 @@ class Quantization(str, Enum):
|
||||
gptq = "gptq"
|
||||
awq = "awq"
|
||||
eetq = "eetq"
|
||||
fp8 = "fp8"
|
||||
|
||||
|
||||
class Dtype(str, Enum):
|
||||
|
@ -23,6 +23,10 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
method_name = method_name.split("/")[-1]
|
||||
logger.exception(f"Method {method_name} encountered an error.")
|
||||
|
||||
# Runtime Error cannot be recovered from
|
||||
if isinstance(err, RuntimeError):
|
||||
exit(1)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -67,6 +67,7 @@ try:
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
@ -144,7 +145,7 @@ def get_model(
|
||||
if speculate is not None:
|
||||
if speculate > speculate_medusa:
|
||||
raise RuntimeError(
|
||||
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||
f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||
)
|
||||
else:
|
||||
set_speculate(speculate)
|
||||
@ -186,6 +187,14 @@ def get_model(
|
||||
raise RuntimeError(
|
||||
f"Could not determine model type for {model_id} revision {revision}"
|
||||
)
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
if quantization_config is not None and quantize is None:
|
||||
method = quantization_config.get("quant_method", None)
|
||||
if method in {"gptq", "awq"}:
|
||||
logger.info(f"Auto selecting quantization method {method}")
|
||||
quantize = method
|
||||
else:
|
||||
logger.info(f"Unknown quantization method {method}")
|
||||
|
||||
if model_type == "ssm":
|
||||
return Mamba(
|
||||
@ -571,6 +580,19 @@ def get_model(
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == "llava_next":
|
||||
if FLASH_ATTENTION:
|
||||
return LlavaNext(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
|
||||
|
||||
if sharded:
|
||||
raise NotImplementedError("sharded is not supported for AutoModel")
|
||||
if quantize == "gptq":
|
||||
|
@ -43,7 +43,7 @@ class CacheManager:
|
||||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int32
|
||||
0, num_blocks * self.block_size, dtype=torch.int64
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(
|
||||
@ -55,9 +55,10 @@ class CacheManager:
|
||||
):
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
assert (
|
||||
len(free_block_indices) >= blocks
|
||||
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
if blocks > len(free_block_indices):
|
||||
raise RuntimeError(
|
||||
f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
)
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[:blocks]
|
||||
|
827
server/text_generation_server/models/custom_modeling/clip.py
Normal file
827
server/text_generation_server/models/custom_modeling/clip.py
Normal file
@ -0,0 +1,827 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_create_4d_causal_attention_mask,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
)
|
||||
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
# TODO Should we TP this ?
|
||||
self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.position_embedding", weights=weights
|
||||
)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values.to(dtype=target_dtype)
|
||||
) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.max_position_embeddings, embed_dim
|
||||
)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = (
|
||||
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = self.embed_dim // self.num_heads
|
||||
if self.head_size * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
self.scale = self.head_size**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return (
|
||||
tensor.view(bsz, seq_len, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
]
|
||||
* 3,
|
||||
dim=2,
|
||||
)
|
||||
query_states = query_states * self.scale
|
||||
key_states = self._shape(key_states, -1, bsz)
|
||||
value_states = self._shape(value_states, -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_size)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ causal_attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_probs = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIPAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.layer_norm2 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPPreTrainedModel(nn.Module):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = CLIPConfig
|
||||
base_model_prefix = "clip"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
||||
CLIP_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
"""
|
||||
|
||||
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
"""
|
||||
|
||||
CLIP_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
return_loss (`bool`, *optional*):
|
||||
Whether or not to return the contrastive loss.
|
||||
"""
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`CLIPEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: CLIPConfig
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
CLIPEncoderLayer(
|
||||
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
"""
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = CLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
# For `pooled_output` computation
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
||||
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _create_4d_causal_attention_mask(
|
||||
input_shape, hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(
|
||||
attention_mask, hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
if self.eos_token_id == 2:
|
||||
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
|
||||
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
|
||||
# ------------------------------------------------------------
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(
|
||||
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||
),
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
|
||||
dim=-1
|
||||
),
|
||||
]
|
||||
else:
|
||||
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(
|
||||
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||
),
|
||||
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
|
||||
(
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device)
|
||||
== self.eos_token_id
|
||||
)
|
||||
.int()
|
||||
.argmax(dim=-1),
|
||||
]
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
|
||||
class CLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = CLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CLIPTextModel
|
||||
|
||||
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
return self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionTransformer(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(
|
||||
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||
)
|
||||
self.pre_layrnorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
# self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
)
|
||||
last_hidden_state = encoder_outputs
|
||||
# pooled_output = last_hidden_state[:, 0, :]
|
||||
# pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
# pooler_output=pooled_output,
|
||||
# hidden_states=encoder_outputs,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_model = CLIPVisionTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPVisionModel
|
||||
|
||||
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
)
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
text_config = config.text_config
|
||||
vision_config = config.vision_config
|
||||
|
||||
self.projection_dim = config.projection_dim
|
||||
self.text_embed_dim = text_config.hidden_size
|
||||
self.vision_embed_dim = vision_config.hidden_size
|
||||
|
||||
self.text_model = CLIPTextTransformer(text_config)
|
||||
self.vision_model = CLIPVisionTransformer(vision_config)
|
||||
|
||||
self.visual_projection = nn.Linear(
|
||||
self.vision_embed_dim, self.projection_dim, bias=False
|
||||
)
|
||||
self.text_projection = nn.Linear(
|
||||
self.text_embed_dim, self.projection_dim, bias=False
|
||||
)
|
||||
self.logit_scale = nn.Parameter(
|
||||
torch.tensor(self.config.logit_scale_init_value)
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
>>> text_features = model.get_text_features(**inputs)
|
||||
```"""
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_features = self.text_projection(pooled_output)
|
||||
|
||||
return text_features
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> image_features = model.get_image_features(**inputs)
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(
|
||||
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
||||
... )
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs[1]
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
return logits_per_image, logits_per_text
|
@ -23,10 +23,10 @@ import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -34,65 +34,106 @@ from text_generation_server.utils.layers import (
|
||||
PositionRotaryEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
else:
|
||||
dropout_layer_norm = None
|
||||
|
||||
class CohereConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
|
||||
class CohereRotary(PositionRotaryEmbedding):
|
||||
def forward(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=8192,
|
||||
intermediate_size=22528,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=5,
|
||||
eos_token_id=255001,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
logit_scale=1.0,
|
||||
**kwargs,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
# Such controlflows may add some overhead.
|
||||
if IS_CUDA_SYSTEM:
|
||||
import rotary_emb
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
q1 = query[..., ::2]
|
||||
q2 = query[..., 1::2]
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.logit_scale = logit_scale
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
k1 = key[..., ::2]
|
||||
k2 = key[..., 1::2]
|
||||
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
||||
|
||||
head_size = query.shape[-1]
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
|
||||
|
||||
class CohereLayerNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps):
|
||||
super().__init__()
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
self.weight = nn.Parameter(weight)
|
||||
# Fake weights
|
||||
self.ones = weight.new_ones(weight.shape[1])
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
hidden_states = hidden_states.reshape(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
hidden_states_minus_mean = hidden_states - mean
|
||||
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = self.weight.to(torch.float32) * hidden_states
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
(
|
||||
hidden_states,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
None,
|
||||
self.ones,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
# Required to apply one weight matrix per head
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
hidden_states = self.weight * hidden_states
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
@ -154,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
self.rotary_emb = CohereRotary.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
@ -175,6 +216,22 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.use_qk_norm = config.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = CohereLayerNorm(
|
||||
prefix=f"{prefix}.q_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.k_norm = CohereLayerNorm(
|
||||
prefix=f"{prefix}.k_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
@ -199,21 +256,28 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
query, key, value = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
self.head_size * self.num_key_value_heads,
|
||||
self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
query = query.reshape(-1, self.head_size)
|
||||
key = key.reshape(-1, self.head_size)
|
||||
query = self.q_norm(query.contiguous())
|
||||
key = self.k_norm(key.contiguous())
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
key = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||
value = value.view(-1, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
@ -223,8 +287,8 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
key,
|
||||
value,
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
@ -298,7 +362,7 @@ class FlashCohereLayer(nn.Module):
|
||||
)
|
||||
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
self.input_layernorm = FastLayerNorm.load_no_bias(
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
@ -362,7 +426,7 @@ class FlashCohereModel(torch.nn.Module):
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
self.norm = FastLayerNorm.load_no_bias(
|
||||
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
|
@ -16,14 +16,13 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from loguru import logger
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
|
||||
)
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
HAS_MEGABLOCKS = True
|
||||
try:
|
||||
import stk
|
||||
import megablocks.ops as ops
|
||||
except ImportError:
|
||||
logger.warning("Dbrx: megablocks is not installed")
|
||||
HAS_MEGABLOCKS = False
|
||||
|
||||
|
||||
class DbrxAttentionConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
@ -531,18 +522,6 @@ def round_up(x: torch.Tensor, value: int):
|
||||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
"""
|
||||
Built on the paper and library Megablocks as described in
|
||||
https://arxiv.org/abs/2211.15841. This implementation is
|
||||
strictly equivalent to standard MoE with full capacity (no
|
||||
dropped tokens). It's faster since it formulates MoE operations
|
||||
in terms of block-sparse operations to accomodate imbalanced
|
||||
assignments of tokens to experts, whereas standard MoE either
|
||||
(1) drop tokens at the cost of reduced performance or (2) set
|
||||
capacity factor to number of experts and thus waste computation
|
||||
and memory on padding.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: DbrxConfig, weights):
|
||||
super().__init__()
|
||||
self.moe_normalize_expert_weights = (
|
||||
@ -572,241 +551,40 @@ class BlockSparseMoE(nn.Module):
|
||||
)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
|
||||
self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
|
||||
self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
|
||||
|
||||
self.offsets = None
|
||||
self.offsets_block_rows = 0
|
||||
w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.wv1 = torch.cat([w1, v1], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts.mlp.w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
# Calculate the number of bits needed to represent the expert indices
|
||||
# so that we can pass it to radix sort.
|
||||
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
|
||||
self.blocking = 128
|
||||
self.quantize_scatter_num_bits = -1
|
||||
|
||||
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
|
||||
padded_tokens, _ = x.size()
|
||||
assert padded_tokens % self.blocking == 0
|
||||
assert self.ffn_dim % self.blocking == 0
|
||||
|
||||
# Offsets for the sparse matrix. All rows have the
|
||||
# same number of nonzero blocks dictated by the
|
||||
# dimensionality of a single expert.
|
||||
block_rows = padded_tokens // self.blocking
|
||||
blocks_per_row = self.ffn_dim // self.blocking
|
||||
if self.offsets is None or block_rows > self.offsets_block_rows:
|
||||
self.offsets = torch.arange(
|
||||
0,
|
||||
block_rows * blocks_per_row + 1,
|
||||
blocks_per_row,
|
||||
dtype=torch.int32,
|
||||
device=x.device,
|
||||
)
|
||||
self.offsets_block_rows = block_rows
|
||||
offsets = self.offsets
|
||||
else:
|
||||
offsets = self.offsets[: block_rows + 1]
|
||||
|
||||
# Indices for the sparse matrix. The indices for
|
||||
# the intermediate matrix are dynamic depending
|
||||
# on the mapping of tokens to experts.
|
||||
column_indices = ops.topology(
|
||||
padded_bins, self.blocking, block_rows, blocks_per_row
|
||||
)
|
||||
|
||||
# For now, use meta init to save the device memory.
|
||||
data = torch.empty(
|
||||
column_indices.numel(),
|
||||
self.blocking,
|
||||
self.blocking,
|
||||
dtype=x.dtype,
|
||||
device="meta",
|
||||
)
|
||||
shape = (padded_tokens, self.ffn_dim * self.num_experts)
|
||||
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
||||
return stk.Matrix(
|
||||
shape,
|
||||
data,
|
||||
row_indices,
|
||||
column_indices,
|
||||
offsets,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
|
||||
# Sort the expert ids to produce the scatter/gather
|
||||
# indices for the permutation.
|
||||
# selected_experts = selected_experts.int()
|
||||
|
||||
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
|
||||
# and indices == how to sort tokens?
|
||||
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
|
||||
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
|
||||
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
|
||||
|
||||
# Histogram the expert ids to identify the number of
|
||||
# tokens routed to each expert.
|
||||
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
|
||||
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
|
||||
|
||||
# Round the token counts up to the block size used in
|
||||
# the matrix muliplications. Caculate the starting
|
||||
# position of each bin.
|
||||
|
||||
# List of size num_experts
|
||||
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
|
||||
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||
|
||||
# Cumulative selected experts per token
|
||||
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
||||
padded_bins = promote_scalar(padded_bins)
|
||||
# padded_bins => [128, 128, 256, ...]
|
||||
|
||||
# Calculate the bin bounds for the sorted tokens.
|
||||
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
||||
bins = promote_scalar(bins)
|
||||
# bins => [3, 3, 5, ...]
|
||||
|
||||
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
||||
|
||||
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
selected_experts, weights = select_experts(
|
||||
gate_logits, self.top_k, self.moe_normalize_expert_weights
|
||||
)
|
||||
|
||||
(
|
||||
indices,
|
||||
bin_ids,
|
||||
bins,
|
||||
padded_bins,
|
||||
_,
|
||||
) = self.indices_and_padded_bins(selected_experts)
|
||||
|
||||
# Permute tokens and pad to prepare expert computation
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
|
||||
|
||||
# Create the sparse matrix topology
|
||||
with torch.no_grad():
|
||||
topo = self.topology(x, padded_bins)
|
||||
|
||||
# Perform the expert computation
|
||||
# First Dense x Dense -> Sparse for w1 and v1,
|
||||
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||
x = stk.Matrix(
|
||||
topo.size(),
|
||||
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
|
||||
* stk.ops.sdd(x, self.v1.t(), topo).data,
|
||||
topo.row_indices,
|
||||
topo.column_indices,
|
||||
topo.offsets,
|
||||
topo.column_indices_t,
|
||||
topo.offsets_t,
|
||||
topo.block_offsets_t,
|
||||
)
|
||||
|
||||
# Then Sparse x Dense -> Dense for w2
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = stk.ops.dsd(x, self.w2)
|
||||
|
||||
# Permute back and remove padding
|
||||
# (sequence_length, model_dim)
|
||||
x = ops.padded_scatter(
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
indices,
|
||||
bin_ids,
|
||||
weights,
|
||||
bins,
|
||||
padded_bins,
|
||||
self.wv1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
self.quantize_scatter_num_bits,
|
||||
).view(*input_shape)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(x, group=self.process_group)
|
||||
|
||||
return x.view(*input_shape)
|
||||
|
||||
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
|
||||
if self.top_k < self.num_experts:
|
||||
_, not_selected_experts = torch.topk(
|
||||
weights,
|
||||
self.num_experts - self.top_k,
|
||||
largest=False,
|
||||
sorted=False,
|
||||
dim=1,
|
||||
)
|
||||
# Mask not selected experts
|
||||
weights.scatter_(1, not_selected_experts, 0)
|
||||
|
||||
# Re-normalize
|
||||
if self.moe_normalize_expert_weights:
|
||||
weights = weights / torch.norm(
|
||||
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
|
||||
)
|
||||
weights = weights.to(x.dtype)
|
||||
|
||||
# Expand to [num_experts, sequence_length, model_dim]
|
||||
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
|
||||
|
||||
# Permute to [num_experts, model_dim, ffn_dim]
|
||||
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
renormalize=self.moe_normalize_expert_weights,
|
||||
inplace=True,
|
||||
)
|
||||
v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
|
||||
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
|
||||
|
||||
out = torch.bmm(
|
||||
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
)
|
||||
# Mask not selected experts
|
||||
out *= weights.t().view(self.num_experts, -1, 1)
|
||||
|
||||
# Sum experts
|
||||
out = out.sum(0)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if len(x) > 256 and HAS_MEGABLOCKS:
|
||||
return self.sparse_forward(x)
|
||||
# This is faster when there is not a lot of tokens
|
||||
return self.dense_forward(x)
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class DenseMoE(nn.Module):
|
||||
|
@ -281,9 +281,8 @@ class LlamaMLP(nn.Module):
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
@ -337,27 +336,30 @@ class FlashLlamaLayer(nn.Module):
|
||||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=(
|
||||
f"model.layers.{layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{layer_id}"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -368,7 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -376,8 +378,10 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
@ -406,13 +410,19 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = FlashLlamaModel(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
@ -426,10 +436,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
@ -437,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
@ -285,9 +285,8 @@ class MistralMLP(nn.Module):
|
||||
|
||||
|
||||
class MistralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = MistralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
@ -343,27 +342,24 @@ class MistralLayer(nn.Module):
|
||||
|
||||
|
||||
class MistralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MistralLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -374,7 +370,7 @@ class MistralModel(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -384,9 +380,8 @@ class MistralModel(torch.nn.Module):
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
@ -410,18 +405,27 @@ class MistralModel(torch.nn.Module):
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashMistralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MistralModel(config, weights)
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = MistralModel(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
@ -453,8 +457,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
@ -24,6 +24,7 @@ import torch.distributed
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
|
||||
get_linear,
|
||||
)
|
||||
|
||||
HAS_MEGABLOCKS = True
|
||||
try:
|
||||
import stk
|
||||
import megablocks.ops as ops
|
||||
except ImportError:
|
||||
logger.warning("Mixtral: megablocks is not installed")
|
||||
HAS_MEGABLOCKS = False
|
||||
|
||||
|
||||
class MixtralConfig(PretrainedConfig):
|
||||
model_type = "mixtral"
|
||||
@ -321,18 +314,6 @@ def round_up(x: torch.Tensor, value: int):
|
||||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
"""
|
||||
Built on the paper and library Megablocks as described in
|
||||
https://arxiv.org/abs/2211.15841. This implementation is
|
||||
strictly equivalent to standard MoE with full capacity (no
|
||||
dropped tokens). It's faster since it formulates MoE operations
|
||||
in terms of block-sparse operations to accomodate imbalanced
|
||||
assignments of tokens to experts, whereas standard MoE either
|
||||
(1) drop tokens at the cost of reduced performance or (2) set
|
||||
capacity factor to number of experts and thus waste computation
|
||||
and memory on padding.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: MixtralConfig, weights):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
@ -357,236 +338,40 @@ class BlockSparseMoE(nn.Module):
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
|
||||
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
|
||||
|
||||
self.offsets = None
|
||||
self.offsets_block_rows = 0
|
||||
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.w13 = torch.cat([w1, w3], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
# Calculate the number of bits needed to represent the expert indices
|
||||
# so that we can pass it to radix sort.
|
||||
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
|
||||
self.blocking = 128
|
||||
self.quantize_scatter_num_bits = -1
|
||||
|
||||
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
|
||||
padded_tokens, _ = x.size()
|
||||
assert padded_tokens % self.blocking == 0
|
||||
assert self.ffn_dim % self.blocking == 0
|
||||
|
||||
# Offsets for the sparse matrix. All rows have the
|
||||
# same number of nonzero blocks dictated by the
|
||||
# dimensionality of a single expert.
|
||||
block_rows = padded_tokens // self.blocking
|
||||
blocks_per_row = self.ffn_dim // self.blocking
|
||||
if self.offsets is None or block_rows > self.offsets_block_rows:
|
||||
self.offsets = torch.arange(
|
||||
0,
|
||||
block_rows * blocks_per_row + 1,
|
||||
blocks_per_row,
|
||||
dtype=torch.int32,
|
||||
device=x.device,
|
||||
)
|
||||
self.offsets_block_rows = block_rows
|
||||
offsets = self.offsets
|
||||
else:
|
||||
offsets = self.offsets[: block_rows + 1]
|
||||
|
||||
# Indices for the sparse matrix. The indices for
|
||||
# the intermediate matrix are dynamic depending
|
||||
# on the mapping of tokens to experts.
|
||||
column_indices = ops.topology(
|
||||
padded_bins, self.blocking, block_rows, blocks_per_row
|
||||
)
|
||||
|
||||
# For now, use meta init to save the device memory.
|
||||
data = torch.empty(
|
||||
column_indices.numel(),
|
||||
self.blocking,
|
||||
self.blocking,
|
||||
dtype=x.dtype,
|
||||
device="meta",
|
||||
)
|
||||
shape = (padded_tokens, self.ffn_dim * self.num_experts)
|
||||
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
||||
return stk.Matrix(
|
||||
shape,
|
||||
data,
|
||||
row_indices,
|
||||
column_indices,
|
||||
offsets,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
|
||||
# Sort the expert ids to produce the scatter/gather
|
||||
# indices for the permutation.
|
||||
# selected_experts = selected_experts.int()
|
||||
|
||||
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
|
||||
# and indices == how to sort tokens?
|
||||
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
|
||||
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
|
||||
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
|
||||
|
||||
# Histogram the expert ids to identify the number of
|
||||
# tokens routed to each expert.
|
||||
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
|
||||
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
|
||||
|
||||
# Round the token counts up to the block size used in
|
||||
# the matrix muliplications. Caculate the starting
|
||||
# position of each bin.
|
||||
|
||||
# List of size num_experts
|
||||
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
|
||||
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||
|
||||
# Cumulative selected experts per token
|
||||
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
||||
padded_bins = promote_scalar(padded_bins)
|
||||
# padded_bins => [128, 128, 256, ...]
|
||||
|
||||
# Calculate the bin bounds for the sorted tokens.
|
||||
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
||||
bins = promote_scalar(bins)
|
||||
# bins => [3, 3, 5, ...]
|
||||
|
||||
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
||||
|
||||
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
selected_experts, weights = select_experts(gate_logits, self.top_k)
|
||||
|
||||
(
|
||||
indices,
|
||||
bin_ids,
|
||||
bins,
|
||||
padded_bins,
|
||||
_,
|
||||
) = self.indices_and_padded_bins(selected_experts)
|
||||
|
||||
# Permute tokens and pad to prepare expert computation
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
|
||||
|
||||
# Create the sparse matrix topology
|
||||
with torch.no_grad():
|
||||
topo = self.topology(x, padded_bins)
|
||||
|
||||
# Perform the expert computation
|
||||
# First Dense x Dense -> Sparse for w1 and w3,
|
||||
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||
x = stk.Matrix(
|
||||
topo.size(),
|
||||
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
|
||||
* stk.ops.sdd(x, self.w3.t(), topo).data,
|
||||
topo.row_indices,
|
||||
topo.column_indices,
|
||||
topo.offsets,
|
||||
topo.column_indices_t,
|
||||
topo.offsets_t,
|
||||
topo.block_offsets_t,
|
||||
)
|
||||
|
||||
# Then Sparse x Dense -> Dense for w2
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = stk.ops.dsd(x, self.w2)
|
||||
|
||||
# Permute back and remove padding
|
||||
# (sequence_length, model_dim)
|
||||
x = ops.padded_scatter(
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
indices,
|
||||
bin_ids,
|
||||
weights,
|
||||
bins,
|
||||
padded_bins,
|
||||
self.w13,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
self.quantize_scatter_num_bits,
|
||||
).view(*input_shape)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(x, group=self.process_group)
|
||||
|
||||
return x.view(*input_shape)
|
||||
|
||||
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
|
||||
if self.top_k < self.num_experts:
|
||||
_, not_selected_experts = torch.topk(
|
||||
all_probs,
|
||||
self.num_experts - self.top_k,
|
||||
largest=False,
|
||||
sorted=False,
|
||||
dim=1,
|
||||
)
|
||||
# Mask not selected experts
|
||||
all_probs.scatter_(1, not_selected_experts, 0)
|
||||
|
||||
# Re-normalize
|
||||
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
|
||||
weights = weights.to(x.dtype)
|
||||
|
||||
# Expand to [num_experts, sequence_length, model_dim]
|
||||
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
|
||||
|
||||
# Permute to [num_experts, model_dim, ffn_dim]
|
||||
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
)
|
||||
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
|
||||
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
|
||||
|
||||
out = torch.bmm(
|
||||
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
)
|
||||
# Mask not selected experts
|
||||
out *= weights.t().view(self.num_experts, -1, 1)
|
||||
|
||||
# Sum experts
|
||||
out = out.sum(0)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if len(x) > 256 and HAS_MEGABLOCKS:
|
||||
return self.sparse_forward(x)
|
||||
# This is faster when there is not a lot of tokens
|
||||
return self.dense_forward(x)
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class DenseMoE(nn.Module):
|
||||
@ -679,9 +464,9 @@ class DenseMoE(nn.Module):
|
||||
|
||||
|
||||
class MixtralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
prefix = f"{prefix}.layers.{layer_id}"
|
||||
|
||||
self.self_attn = MixtralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
@ -740,16 +525,20 @@ class MixtralLayer(nn.Module):
|
||||
|
||||
|
||||
class MixtralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MixtralLayer(
|
||||
"model" if not prefix else f"{prefix}.model",
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
@ -758,7 +547,9 @@ class MixtralModel(torch.nn.Module):
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
@ -808,13 +599,13 @@ class MixtralModel(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MixtralModel(config, weights)
|
||||
self.model = MixtralModel(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
@ -0,0 +1,302 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def load_vision_model(prefix, config, weights):
|
||||
if config.model_type == "clip_vision_model":
|
||||
from text_generation_server.models.custom_modeling.clip import (
|
||||
CLIPVisionTransformer,
|
||||
)
|
||||
|
||||
return CLIPVisionTransformer(
|
||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
def load_text_model(prefix, config, weights):
|
||||
if config.model_type == "llama":
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
|
||||
return FlashMistralForCausalLM(prefix, config, weights)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.use_medusa = config.use_medusa
|
||||
self.language_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
@ -106,6 +106,19 @@ class FlashCausalLMBatch(Batch):
|
||||
max_tokens=self.blocks * BLOCK_SIZE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs(cls, requests, tokenizer):
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
return batch_tokenized_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
@ -114,16 +127,7 @@ class FlashCausalLMBatch(Batch):
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
position_ids = []
|
||||
speculative_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
@ -165,6 +169,11 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
if (
|
||||
tokenized_input[0] == tokenizer.bos_token_id
|
||||
and tokenized_input[1] == tokenizer.bos_token_id
|
||||
):
|
||||
tokenized_input = tokenized_input[1:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
@ -690,7 +699,7 @@ class FlashCausalLM(Model):
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
@ -805,7 +814,7 @@ class FlashCausalLM(Model):
|
||||
for bs in CUDA_GRAPHS:
|
||||
if self.speculate is None or self.speculate + 1 <= bs:
|
||||
self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||
except Exception:
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
logger.exception(f"Decode cuda graph warmup failed")
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
@ -865,22 +874,14 @@ class FlashCausalLM(Model):
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
if sorted_padded_bs:
|
||||
# Get associated cuda graph
|
||||
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
||||
else:
|
||||
cuda_graph = None
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if (
|
||||
cu_seqlen_prefill is not None
|
||||
or cuda_graph is None
|
||||
or batch.speculative_ids is not None
|
||||
):
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -3,12 +3,11 @@ import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||
FlashCohereForCausalLM,
|
||||
CohereConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
@ -32,7 +31,7 @@ class FlashCohere(FlashCausalLM):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashCohere is only available on GPU")
|
||||
|
||||
@ -46,7 +45,7 @@ class FlashCohere(FlashCausalLM):
|
||||
from_slow=False,
|
||||
)
|
||||
|
||||
config = CohereConfig.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
@ -67,7 +67,8 @@ class FlashLlama(FlashCausalLM):
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
prefix = ""
|
||||
model = FlashLlamaForCausalLM(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
|
@ -6,8 +6,7 @@ import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
@ -65,19 +64,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
@classmethod
|
||||
def from_tokenized(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
batch_tokenized_inputs,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
sliding_window, sliding_window_blocks = get_sliding_windows()
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
@ -301,14 +302,15 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config_cls,
|
||||
model_cls,
|
||||
model_id: str,
|
||||
config_cls=AutoConfig,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
@ -317,22 +319,13 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
else:
|
||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||
|
||||
try:
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
except Exception:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = config_cls.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
@ -341,10 +334,12 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
if getattr(config, "sliding_window", None) is not None:
|
||||
set_sliding_window(
|
||||
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
)
|
||||
else:
|
||||
config.sliding_window = None
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
@ -353,17 +348,19 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = model_cls(config, weights)
|
||||
prefix = ""
|
||||
model = model_cls(prefix, config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
@ -371,6 +368,16 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
sliding_window=config.sliding_window,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.model.layers),
|
||||
model.model.num_key_value_heads,
|
||||
model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> int:
|
||||
return self.model.max_past
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||
return FlashMistralBatch
|
||||
@ -378,7 +385,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
@ -485,11 +492,11 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.model.max_past is not None:
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.model.max_past, max_s)
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
|
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import torch
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
@ -20,29 +21,13 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.models.vlm_causal_lm import split
|
||||
|
||||
import re
|
||||
|
||||
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
|
||||
|
||||
|
||||
def split(string):
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append(string[cursor:start])
|
||||
|
||||
parts.append(pattern.group(1))
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append(string[cursor:])
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@ -93,10 +78,21 @@ class IdeficsCausalLMBatch(Batch):
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
processor: ProcessorMixin, # Hack
|
||||
config,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
@ -127,10 +123,14 @@ class IdeficsCausalLMBatch(Batch):
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
# TODO Check impact on idefics
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
prompts.append(split(inp))
|
||||
prompt = []
|
||||
for chunk in split(inp):
|
||||
prompt.append(chunk["content"])
|
||||
prompts.append(prompt)
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
@ -141,7 +141,8 @@ class IdeficsCausalLMBatch(Batch):
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
# TODO Check impact on idefics
|
||||
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
@ -156,7 +157,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
pixel_values = tokenized_inputs["pixel_values"]
|
||||
pixel_values = tokenized_inputs.get("pixel_values", None)
|
||||
image_hidden_states = None
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
@ -165,16 +166,19 @@ class IdeficsCausalLMBatch(Batch):
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
# Do the same for image_attention_mask
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
tokenized_inputs["pixel_values"].size(1),
|
||||
if pixel_values is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
pixel_values.size(1),
|
||||
)
|
||||
)
|
||||
)
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
@ -677,19 +681,22 @@ class IdeficsCausalLM(Model):
|
||||
start = time.time_ns()
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
if batch.image_attention_mask is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
|
||||
logits, speculative_logits, past, image_hidden_states = self.forward(
|
||||
input_ids=batch.input_ids,
|
||||
|
36
server/text_generation_server/models/llava_next.py
Normal file
36
server/text_generation_server/models/llava_next.py
Normal file
@ -0,0 +1,36 @@
|
||||
import torch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
|
||||
class LlavaNext(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
super().__init__(
|
||||
model_cls=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
329
server/text_generation_server/models/vlm_causal_lm.py
Normal file
329
server/text_generation_server/models/vlm_causal_lm.py
Normal file
@ -0,0 +1,329 @@
|
||||
import re
|
||||
import torch
|
||||
import math
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
FlashMistralBatch,
|
||||
)
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
|
||||
|
||||
|
||||
def split(string) -> List[Dict[str, str]]:
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append({"type": "text", "content": string[cursor:start]})
|
||||
|
||||
parts.append({"type": "image", "content": pattern.group(1)})
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append({"type": "text", "content": string[cursor:]})
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def get_number_of_features(height: int, width: int, config) -> int:
|
||||
# From config
|
||||
# Hardcoded for CLIP for now
|
||||
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
|
||||
image_grid_pinpoints = config.image_grid_pinpoints
|
||||
image_size = config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
|
||||
assert image_size % patch_size == 0
|
||||
|
||||
npatches = image_size // patch_size
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
[height, width],
|
||||
image_grid_pinpoints,
|
||||
image_size,
|
||||
)
|
||||
|
||||
height_of_patch = math.ceil(height / width * npatches)
|
||||
|
||||
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
|
||||
# They are only added after width
|
||||
newline_features = height_of_patch * num_patch_width
|
||||
# The base patch covers the entire image
|
||||
base_features = npatches**2
|
||||
return unpadded_features + newline_features + base_features
|
||||
|
||||
|
||||
def load_data_uri(image_uri: str) -> Image.Image:
|
||||
image_uri = image_uri.split(",")[-1]
|
||||
content = base64.b64decode(image_uri)
|
||||
image = Image.open(BytesIO(content))
|
||||
return image
|
||||
|
||||
|
||||
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
|
||||
# assert get_number_of_features(640, 640) == 2928
|
||||
|
||||
|
||||
class VlmCausalLMBatch(FlashMistralBatch):
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches):
|
||||
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
||||
batch.pixel_values = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]):
|
||||
batch = super().filter(request_ids)
|
||||
batch.pixel_values = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
|
||||
batch_inputs = []
|
||||
image_inputs = []
|
||||
max_truncation = 0
|
||||
for r in requests:
|
||||
chunks = split(r.inputs)
|
||||
full_text = ""
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "text":
|
||||
full_text += chunk["content"]
|
||||
elif chunk["type"] == "image":
|
||||
image = chunk["content"]
|
||||
# Should never receive URLs anymore, processing should be done
|
||||
# On the rust layer.
|
||||
# This avoid making n queries per TP
|
||||
# if image.startswith("https://") or image.startswith("http://"):
|
||||
# image = processor.image_processor.fetch_images(image)
|
||||
if image.startswith("data:"):
|
||||
image = load_data_uri(image)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot process input image not starting with data:"
|
||||
)
|
||||
image_input = processor.image_processor(image, return_tensors="pt")
|
||||
height, width = image_input["image_sizes"][0]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
full_text += "<image>" * num_features
|
||||
image_inputs.append(image_input)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||
|
||||
batch_inputs.append(full_text)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
if image_inputs:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(
|
||||
[img["pixel_values"] for img in image_inputs], dim=0
|
||||
),
|
||||
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
|
||||
}
|
||||
else:
|
||||
image_inputs = None
|
||||
return batch_tokenized_inputs, image_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
processor,
|
||||
config,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "VlmCausalLMBatch":
|
||||
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
||||
pb.requests, tokenizer, processor, config
|
||||
)
|
||||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
if image_inputs is not None:
|
||||
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||
else:
|
||||
batch.pixel_values = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
|
||||
class VlmCausalLM(BaseFlashMistral):
|
||||
@property
|
||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||
return VlmCausalLMBatch
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.language_model.model.layers),
|
||||
model.language_model.model.num_key_value_heads,
|
||||
model.language_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.language_model, "max_past", None)
|
||||
|
||||
def forward(
|
||||
self, batch: VlmCausalLMBatch
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
speculative_ids = batch.speculative_ids
|
||||
|
||||
B, speculative_length = speculative_ids.shape
|
||||
new_length = speculative_length + 1
|
||||
new_input_ids = torch.cat(
|
||||
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
||||
).reshape(-1)
|
||||
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
||||
arange_int = arange.to(dtype=torch.int32)
|
||||
new_position_ids = (
|
||||
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||
).view(-1)
|
||||
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||
input_lengths = (
|
||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||
).view(-1)
|
||||
|
||||
# Add Copy the block tables for all members
|
||||
block_tables = (
|
||||
block_tables.unsqueeze(1)
|
||||
.expand(B, new_length, -1)
|
||||
.reshape(B * new_length, -1)
|
||||
.contiguous()
|
||||
)
|
||||
max_s = max_s + speculative_length
|
||||
|
||||
input_ids = new_input_ids
|
||||
position_ids = new_position_ids
|
||||
else:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
pixel_values=batch.pixel_values,
|
||||
image_sizes=batch.image_sizes,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
if batch.pixel_values is not None:
|
||||
batch.pixel_values = None
|
||||
if batch.image_sizes is not None:
|
||||
batch.image_sizes = None
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
cuda_graph["slots"].fill_(-1)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
# Slice output to the correct shape
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
@ -13,6 +13,7 @@ from typing import List, Optional
|
||||
from text_generation_server.cache import Cache
|
||||
from text_generation_server.interceptor import ExceptionInterceptor
|
||||
from text_generation_server.models import Model, get_model
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
@ -78,13 +79,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
if self.model.batch_type in {
|
||||
IdeficsCausalLMBatch,
|
||||
VlmCausalLMBatch,
|
||||
}: # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb_processor(
|
||||
request.batch,
|
||||
self.model.tokenizer,
|
||||
self.model.processor,
|
||||
self.model.model.config,
|
||||
self.model.dtype,
|
||||
self.model.device,
|
||||
)
|
||||
@ -100,13 +103,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
start = time.time_ns()
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
if self.model.batch_type in {
|
||||
IdeficsCausalLMBatch,
|
||||
VlmCausalLMBatch,
|
||||
}: # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb_processor(
|
||||
request.batch,
|
||||
self.model.tokenizer,
|
||||
self.model.processor,
|
||||
self.model.model.config,
|
||||
self.model.dtype,
|
||||
self.model.device,
|
||||
)
|
||||
|
@ -88,6 +88,9 @@ def attention(
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
|
@ -19,7 +19,6 @@ from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
HAS_AWQ = True
|
||||
try:
|
||||
@ -35,12 +34,6 @@ except Exception:
|
||||
HAS_EXLLAMA = False
|
||||
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
|
||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
||||
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
||||
# V2 = False
|
||||
# log_once(
|
||||
# logger.warning,
|
||||
# "Disabling exllama v2 and using v1 instead because there are issues when sharding",
|
||||
# )
|
||||
|
||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||
HAS_EXLLAMA = False
|
||||
@ -174,6 +167,8 @@ class EETQLinear(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
device = weight.device
|
||||
if weight.dtype != torch.float16:
|
||||
weight = weight.to(dtype=torch.float16)
|
||||
weight = torch.t(weight).contiguous().cpu()
|
||||
weight, scale = quant_weights(weight, torch.int8, False)
|
||||
|
||||
@ -187,6 +182,48 @@ class EETQLinear(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
||||
device = weight.device
|
||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||
finfo = torch.finfo(qdtype)
|
||||
# Calculate the scale as dtype max divided by absmax
|
||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
|
||||
# scale and clamp the tensor to bring it to
|
||||
# the representative range of float8 data type
|
||||
# (as default cast is unsaturated)
|
||||
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
# Return both float8 data and the inverse scale (as float),
|
||||
# as both required as inputs to torch._scaled_mm
|
||||
qweight = qweight.to(qdtype)
|
||||
scale = scale.float().reciprocal()
|
||||
return qweight, scale
|
||||
|
||||
|
||||
class Fp8Linear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = weight.dtype
|
||||
self.qweight, self.scale = fp8_quantize(weight)
|
||||
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
qinput, scale = fp8_quantize(input)
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
self.qweight.t(),
|
||||
out_dtype=self.dtype,
|
||||
scale_a=scale,
|
||||
scale_b=self.scale,
|
||||
bias=self.bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -298,6 +335,8 @@ def get_linear(weight, bias, quantize):
|
||||
raise ImportError(
|
||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||
)
|
||||
elif quantize == "fp8":
|
||||
linear = Fp8Linear(weight, bias)
|
||||
elif quantize == "bitsandbytes":
|
||||
warn_deprecate_bnb()
|
||||
linear = Linear8bitLt(
|
||||
@ -393,12 +432,12 @@ class ResBlock(torch.nn.Module):
|
||||
|
||||
|
||||
class MedusaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, config, medusa_config, weights):
|
||||
super().__init__()
|
||||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
||||
for i in range(config["medusa_num_heads"])
|
||||
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||
for i in range(medusa_config["medusa_num_heads"])
|
||||
]
|
||||
)
|
||||
|
||||
@ -408,12 +447,12 @@ class MedusaModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MedusaHead(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
def __init__(self, config, medusa_config, prefix, weights):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||
for i in range(config["medusa_num_layers"])
|
||||
for i in range(medusa_config["medusa_num_layers"])
|
||||
]
|
||||
)
|
||||
n = len(self.blocks)
|
||||
@ -428,7 +467,7 @@ class MedusaHead(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SpeculativeHead(nn.Module):
|
||||
class MedusaHeadV1(nn.Module):
|
||||
def __init__(self, lm_head, medusa):
|
||||
super().__init__()
|
||||
self.lm_head = lm_head
|
||||
@ -436,38 +475,156 @@ class SpeculativeHead(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
use_medusa = config.use_medusa
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing and routing[k] != filename:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, medusa_config, weights)
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
return MedusaHeadV1(lm_head, medusa)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self.lm_head(input)
|
||||
# If we have too many tokens, we skip speculative logits
|
||||
if input.shape[0] > 128:
|
||||
return logits, None
|
||||
|
||||
speculative_logits = self.medusa(input)
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class MedusaHeadV2(nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
use_medusa = config.use_medusa
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing and routing[k] != filename:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
|
||||
self.n_medusa_heads = medusa_config["medusa_num_heads"]
|
||||
|
||||
assert medusa_config["medusa_num_layers"] == 1
|
||||
self.linear = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.process_group = weights.process_group
|
||||
self.world_size = self.process_group.size()
|
||||
self.rank = self.process_group.rank()
|
||||
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
self.lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
|
||||
def forward(self, x):
|
||||
# If we have too many tokens, we skip speculative logits
|
||||
if x.shape[0] > 128:
|
||||
logits = self.lm_head(x)
|
||||
return logits, None
|
||||
|
||||
size = x.shape[-1]
|
||||
block_size = (size + self.world_size - 1) // self.world_size
|
||||
start = self.rank * block_size
|
||||
stop = (self.rank + 1) * block_size
|
||||
|
||||
x_block = x[:, start:stop]
|
||||
|
||||
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
|
||||
medusa_res = self.act(self.linear(x)).reshape(
|
||||
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
|
||||
)
|
||||
|
||||
# Apply all residual medusa heads
|
||||
output = x[:, start:stop].unsqueeze(-2) + medusa_res
|
||||
|
||||
# Gather medusa heads
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
|
||||
# Stack x and medusa residual x
|
||||
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
|
||||
|
||||
# Compute lm head on x + medusa residual x
|
||||
logits = self.lm_head(stacked_x)
|
||||
|
||||
# Finally, split logits from speculative logits
|
||||
logits, speculative_logits = torch.split(
|
||||
logits, [1, self.n_medusa_heads], dim=-2
|
||||
)
|
||||
# Squeeze added dimension
|
||||
logits = logits.squeeze(-2)
|
||||
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class SpeculativeHead(nn.Module):
|
||||
def __init__(self, lm_head, medusa):
|
||||
super().__init__()
|
||||
self.head = lm_head
|
||||
self.medusa = medusa
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
use_medusa = config.use_medusa
|
||||
if use_medusa:
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
weights.routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, weights)
|
||||
lm_head = None
|
||||
try:
|
||||
medusa = MedusaHeadV1.load(config, prefix, weights)
|
||||
except:
|
||||
medusa = MedusaHeadV2(config, prefix, weights)
|
||||
else:
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
medusa = None
|
||||
return SpeculativeHead(lm_head, medusa)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self.lm_head(input)
|
||||
speculative_logits = self.medusa(input) if self.medusa is not None else None
|
||||
return logits, speculative_logits
|
||||
if self.medusa is not None:
|
||||
return self.medusa(input)
|
||||
|
||||
assert self.head is not None
|
||||
logits = self.head(input)
|
||||
return logits, None
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
|
@ -1,8 +1,6 @@
|
||||
import torch
|
||||
|
||||
# vllm imports
|
||||
from vllm import cache_ops
|
||||
from vllm import attention_ops
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
@ -14,7 +12,18 @@ def reshape_and_cache(
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
if IS_CUDA_SYSTEM:
|
||||
from vllm._C import cache_ops
|
||||
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import cache_ops
|
||||
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
else:
|
||||
raise ValueError("vllm is not supported on your system")
|
||||
|
||||
|
||||
def attention(
|
||||
@ -54,21 +63,45 @@ def attention(
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
if use_v1:
|
||||
attention_ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
if IS_CUDA_SYSTEM:
|
||||
from vllm._C import ops
|
||||
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import attention_ops
|
||||
|
||||
attention_ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("vllm is not supported on your system")
|
||||
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
@ -83,19 +116,46 @@ def attention(
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
attention_ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
|
||||
if IS_CUDA_SYSTEM:
|
||||
from vllm._C import ops
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import attention_ops
|
||||
|
||||
attention_ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("vllm is not supported on your system")
|
||||
|
5
tgi-entrypoint.sh
Executable file
5
tgi-entrypoint.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
|
||||
|
||||
text-generation-launcher $@
|
Loading…
Reference in New Issue
Block a user