Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00

Commit 59849777de: Merge branch 'main' into ci_amd3

.github/workflows/build.yaml (vendored): 2 lines changed
@@ -225,7 +225,7 @@ jobs:
   runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
   if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
   env:
-    PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
+    PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }}
   steps:
     - name: Checkout repository
       uses: actions/checkout@v4
Cargo.lock (generated): 599 lines changed. File diff suppressed because it is too large.
@@ -9,7 +9,7 @@ members = [
 resolver = "2"

 [workspace.package]
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -75,9 +75,11 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c

 ```shell
 model=HuggingFaceH4/zephyr-7b-beta
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+# share a volume with the Docker container to avoid downloading weights every run
+volume=$PWD/data

-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $model
 ```

 And then you can make requests like
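A streaming request against the server started above might look like the following (a sketch; the README's own example payload is elided by this diff, only the endpoint in the next hunk header is confirmed):

```shell
curl 127.0.0.1:8080/generate_stream \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
```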
@@ -91,7 +93,7 @@ curl 127.0.0.1:8080/generate_stream \

 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. To run the Docker container on a machine with no GPUs or CUDA support, remove the `--gpus all` flag and add `--disable-custom-kernels`; note that CPU is not the intended platform for this project, so performance might be subpar.
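For a CPU-only machine, the note above translates to dropping `--gpus all` and passing `--disable-custom-kernels` to the launcher; a minimal sketch, using the same image tag as above:

```shell
# CPU-only sketch: no --gpus flag, custom CUDA kernels disabled
docker run --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.1.0 \
    --model-id $model --disable-custom-kernels
```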
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.0-rocm --model-id $model` instead of the command above.

 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading

 docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
     --device=/dev/kfd --device=/dev/dri --group-add video \
     --ipc=host --shm-size 256g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \
+    ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \
     --model-id $model
 ```
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

 docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.0.4 \
+    ghcr.io/huggingface/text-generation-inference:2.1.0 \
     --model-id $model
 ```
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.0.4 \
+    ghcr.io/huggingface/text-generation-inference:2.1.0 \
     --model-id $model
 ```
@@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \
 To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

 ```bash
-docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help
+docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help
 ```

 </Tip>
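As a rough illustration of those options, a sketch combining a few launcher flags (the flag names `--num-shard`, `--quantize`, and `--max-total-tokens` are assumed here, not taken from this diff; confirm them against `--help` for your image version):

```bash
# Assumed flags: shard across 2 GPUs, load GPTQ weights, cap total tokens per request
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.1.0 \
    --model-id $model --num-shard 2 --quantize gptq --max-total-tokens 4096
```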
@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
+- [Gemma2](https://huggingface.co/google/gemma2-9b)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
 - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
@@ -1,84 +0,0 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2323, "logprob": null, "text": "Test" },
      { "id": 1715, "logprob": -11.34375, "text": " request" }
    ],
    "seed": null,
    "tokens": [
      { "id": 198, "logprob": -2.5742188, "special": false, "text": "\n" },
      { "id": 262, "logprob": -1.6230469, "special": false, "text": " " },
      { "id": 3270, "logprob": -2.046875, "special": false, "text": " \"\"\"\n" },
      { "id": 262, "logprob": -0.015281677, "special": false, "text": " " },
      { "id": 422, "logprob": -2.1425781, "special": false, "text": " if" },
      { "id": 1715, "logprob": -0.9238281, "special": false, "text": " request" },
      { "id": 13204, "logprob": -0.076660156, "special": false, "text": ".method" },
      { "id": 624, "logprob": -0.021987915, "special": false, "text": " ==" },
      { "id": 364, "logprob": -0.39208984, "special": false, "text": " '" },
      { "id": 3019, "logprob": -0.10821533, "special": false, "text": "POST" }
    ],
    "top_tokens": null
  },
  "generated_text": "\n \"\"\"\n if request.method == 'POST"
}
@@ -1,84 +0,0 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2323, "logprob": null, "text": "Test" },
      { "id": 1715, "logprob": -11.34375, "text": " request" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 13, "logprob": -2.2539062, "special": false, "text": "." },
      { "id": 578, "logprob": -0.15563965, "special": false, "text": " The" },
      { "id": 3622, "logprob": -0.8203125, "special": false, "text": " server" },
      { "id": 706, "logprob": 0.0, "special": false, "text": " has" },
      { "id": 539, "logprob": 0.0, "special": false, "text": " not" },
      { "id": 3686, "logprob": 0.0, "special": false, "text": " yet" },
      { "id": 3288, "logprob": 0.0, "special": false, "text": " sent" },
      { "id": 904, "logprob": 0.0, "special": false, "text": " any" },
      { "id": 828, "logprob": 0.0, "special": false, "text": " data" },
      { "id": 382, "logprob": -1.5517578, "special": false, "text": ".\n\n" }
    ],
    "top_tokens": null
  },
  "generated_text": "Test request. The server has not yet sent any data.\n\n"
}
@@ -1,338 +0,0 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2323, "logprob": null, "text": "Test" },
        { "id": 1715, "logprob": -11.34375, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 198, "logprob": -2.5742188, "special": false, "text": "\n" },
        { "id": 262, "logprob": -1.6220703, "special": false, "text": " " },
        { "id": 3270, "logprob": -2.0410156, "special": false, "text": " \"\"\"\n" },
        { "id": 262, "logprob": -0.015281677, "special": false, "text": " " },
        { "id": 422, "logprob": -2.1445312, "special": false, "text": " if" },
        { "id": 1715, "logprob": -0.92333984, "special": false, "text": " request" },
        { "id": 13204, "logprob": -0.07672119, "special": false, "text": ".method" },
        { "id": 624, "logprob": -0.021987915, "special": false, "text": " ==" },
        { "id": 364, "logprob": -0.39208984, "special": false, "text": " '" },
        { "id": 3019, "logprob": -0.10638428, "special": false, "text": "POST" }
      ],
      "top_tokens": null
    },
    "generated_text": "\n \"\"\"\n if request.method == 'POST"
  },
  (three more array entries, identical to the first)
]
File diff suppressed because it is too large
@@ -8,61 +8,61 @@
   "tokens": [
-    { "id": 330, "logprob": -0.13000488, "special": false, "text": " A" },
+    { "id": 330, "logprob": -0.08660889, "special": false, "text": " A" },
-    { "id": 13088, "logprob": -0.6713867, "special": false, "text": " chicken" },
+    { "id": 13088, "logprob": -0.7089844, "special": false, "text": " chicken" },
-    { "id": 349, "logprob": -0.2980957, "special": false, "text": " is" },
+    { "id": 349, "logprob": -0.32885742, "special": false, "text": " is" },
-    { "id": 6398, "logprob": -0.060638428, "special": false, "text": " sitting" },
+    { "id": 6398, "logprob": -0.05126953, "special": false, "text": " sitting" },
-    { "id": 356, "logprob": -0.27319336, "special": false, "text": " on" },
+    { "id": 356, "logprob": -0.35229492, "special": false, "text": " on" },
-    { "id": 264, "logprob": -0.140625, "special": false, "text": " a" },
+    { "id": 264, "logprob": -0.12561035, "special": false, "text": " a" },
-    { "id": 17972, "logprob": -0.040405273, "special": false, "text": " pile" },
+    { "id": 17972, "logprob": -0.038085938, "special": false, "text": " pile" },
-    { "id": 302, "logprob": -0.0002708435, "special": false, "text": " of" },
+    { "id": 302, "logprob": -0.00018656254, "special": false, "text": " of" },
-    { "id": 2445, "logprob": -0.095336914, "special": false, "text": " money" },
+    { "id": 2445, "logprob": -0.07293701, "special": false, "text": " money" },
-    { "id": 28723, "logprob": -0.0068359375, "special": false, "text": "." },
+    { "id": 28723, "logprob": -0.004852295, "special": false, "text": "." }
@@ -8,115 +8,115 @@
   "tokens": [
-    { "id": 415, "logprob": -0.04421997, "special": false, "text": " The" },
+    { "id": 415, "logprob": -0.039886475, "special": false, "text": " The" },
-    { "id": 12072, "logprob": -0.13500977, "special": false, "text": " cow" },
+    { "id": 12072, "logprob": -0.1430664, "special": false, "text": " cow" },
-    { "id": 349, "logprob": -0.06750488, "special": false, "text": " is" },
+    { "id": 349, "logprob": -0.056488037, "special": false, "text": " is" },
-    { "id": 6328, "logprob": -0.6352539, "special": false, "text": " standing" },
+    { "id": 6328, "logprob": -0.6855469, "special": false, "text": " standing" },
-    { "id": 356, "logprob": -0.16186523, "special": false, "text": " on" },
+    { "id": 356, "logprob": -0.1685791, "special": false, "text": " on" },
-    { "id": 272, "logprob": -0.5078125, "special": false, "text": " the" },
+    { "id": 272, "logprob": -0.50097656, "special": false, "text": " the" },
-    { "id": 10305, "logprob": -0.017913818, "special": false, "text": " beach" },
+    { "id": 10305, "logprob": -0.017303467, "special": false, "text": " beach" },
-    { "id": 304, "logprob": -1.5205078, "special": false, "text": " and" },
+    { "id": 304, "logprob": -1.3564453, "special": false, "text": " and" },
-    { "id": 272, "logprob": -0.029174805, "special": false, "text": " the" },
+    { "id": 272, "logprob": -0.017868042, "special": false, "text": " the" },
-    { "id": 13088, "logprob": -0.003479004, "special": false, "text": " chicken" },
+    { "id": 13088, "logprob": -0.0027103424, "special": false, "text": " chicken" },
-    { "id": 349, "logprob": -0.0035095215, "special": false, "text": " is" },
+    { "id": 349, "logprob": -0.003156662, "special": false, "text": " is" },
-    { "id": 6398, "logprob": -0.3088379, "special": false, "text": " sitting" },
+    { "id": 6398, "logprob": -0.37304688, "special": false, "text": " sitting" },
-    { "id": 356, "logprob": -0.027755737, "special": false, "text": " on" },
+    { "id": 356, "logprob": -0.034576416, "special": false, "text": " on" },
-    { "id": 264, "logprob": -0.31884766, "special": false, "text": " a" },
+    { "id": 264, "logprob": -0.29418945, "special": false, "text": " a" },
-    { "id": 17972, "logprob": -0.047943115, "special": false, "text": " pile" },
+    { "id": 17972, "logprob": -0.042877197, "special": false, "text": " pile" },
-    { "id": 302, "logprob": -0.0002925396, "special": false, "text": " of" },
+    { "id": 302, "logprob": -0.00028443336, "special": false, "text": " of" },
-    { "id": 2445, "logprob": -0.02935791, "special": false, "text": " money" },
+    { "id": 2445, "logprob": -0.023223877, "special": false, "text": " money" },
-    { "id": 28723, "logprob": -0.031219482, "special": false, "text": "." },
+    { "id": 28723, "logprob": -0.018157959, "special": false, "text": "." },
-    { "id": 32002, "logprob": -0.00034475327, "special": true, "text": "<end_of_utterance>" },
+    { "id": 32002, "logprob": -0.00018393993, "special": true, "text": "<end_of_utterance>" },
@@ -898,13 +898,20 @@ enum LauncherError {
     WebserverCannotStart,
 }

-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(
+    model_id: &str,
+    revision: Option<&str>,
+    trust_remote_code: bool,
+    huggingface_hub_cache: Option<&str>,
+    weights_cache_override: Option<&str>,
+    running: Arc<AtomicBool>,
+) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();

     let mut download_args = vec![
         "download-weights".to_string(),
-        args.model_id.to_string(),
+        model_id.to_string(),
         "--extension".to_string(),
         ".safetensors".to_string(),
         "--logger-level".to_string(),
@@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     ];

     // Model optional revision
-    if let Some(revision) = &args.revision {
+    if let Some(revision) = &revision {
         download_args.push("--revision".to_string());
         download_args.push(revision.to_string())
     }

     // Trust remote code for automatic peft fusion
-    if args.trust_remote_code {
+    if trust_remote_code {
         download_args.push("--trust-remote-code".to_string());
     }

@@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {

     // If huggingface_hub_cache is set, pass it to the download process
     // Useful when running inside a docker container
-    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
+    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
         envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };

@@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {

     // If args.weights_cache_override is some, pass it to the download process
     // Useful when running inside a HuggingFace Inference Endpoint
-    if let Some(weights_cache_override) = &args.weights_cache_override {
+    if let Some(weights_cache_override) = &weights_cache_override {
         envs.push((
             "WEIGHTS_CACHE_OVERRIDE".into(),
             weights_cache_override.into(),
@@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     };

     // Start process
-    tracing::info!("Starting download process.");
+    tracing::info!("Starting check and download process for {model_id}");
     let mut download_process = match Command::new("text-generation-server")
         .args(download_args)
         .env_clear()
@@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     loop {
         if let Some(status) = download_process.try_wait().unwrap() {
             if status.success() {
-                tracing::info!("Successfully downloaded weights.");
+                tracing::info!("Successfully downloaded weights for {model_id}");
                 break;
             }

@@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
     .expect("Error setting Ctrl-C handler");

     // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
+    download_convert_model(
+        &args.model_id,
+        args.revision.as_deref(),
+        args.trust_remote_code,
+        args.huggingface_hub_cache.as_deref(),
+        args.weights_cache_override.as_deref(),
+        running.clone(),
+    )?;
+
+    // Download and convert lora adapters if any
+    if let Some(lora_adapters) = &args.lora_adapters {
+        for adapter in lora_adapters.split(',') {
+            download_convert_model(
+                adapter,
+                None,
+                args.trust_remote_code,
+                args.huggingface_hub_cache.as_deref(),
+                args.weights_cache_override.as_deref(),
+                running.clone(),
+            )?;
+        }
+    }

     if !running.load(Ordering::SeqCst) {
         // Launcher was asked to stop
@@ -22,9 +22,10 @@ text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
 hf-hub = { workspace = true }
+itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
-metrics-exporter-prometheus = { version = "0.12.1", features = [] }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"
@@ -37,9 +38,9 @@ tokenizers = { workspace = true}
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.5.1", features = ["cors"] }
-tracing = "0.1.37"
+tracing = "0.1.40"
 tracing-opentelemetry = "0.21.0"
-tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
+tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
 utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
@ -71,10 +71,12 @@ fn get_unpadded_features(
|
||||
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
|
||||
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
|
||||
let new_height = (height * current_width) / width;
|
||||
(new_height, current_width)
|
||||
let padding = (current_height - new_height) / 2;
|
||||
(current_height - (2 * padding), current_width)
|
||||
} else {
|
||||
let new_width = (width * current_height) / height;
|
||||
(current_height, new_width)
|
||||
let padding = (current_width - new_width) / 2;
|
||||
(current_height, current_width - (2 * padding))
|
||||
};
|
||||
|
||||
let unpadded_features = current_height * current_width;
|
||||
@ -88,7 +90,9 @@ impl LlavaNext {
|
||||
let patch_size = self.vision_config.patch_size;
|
||||
assert!(image_size % patch_size == 0);
|
||||
let npatches = image_size / patch_size;
|
||||
let (num_patch_height, num_patch_width) =
|
||||
// Dimensions are intentionally swapped to be bug-compatible with
|
||||
// upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
let (num_patch_width, num_patch_height) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
|
||||
let (unpadded_features, newline_features) =
|
||||
@ -112,7 +116,7 @@ pub struct Idefics2 {}
|
||||
|
||||
impl Idefics2 {
|
||||
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
|
||||
320
|
||||
64
|
||||
}
|
||||
}
|
||||
|
||||
@ -158,6 +162,7 @@ pub enum Config {
|
||||
Baichuan,
|
||||
Paligemma(Paligemma),
|
||||
Gemma,
|
||||
Gemma2,
|
||||
Cohere,
|
||||
Drbx,
|
||||
Falcon,
|
||||
|
@ -61,6 +61,9 @@ pub struct HubTokenizerConfig {
|
||||
pub bos_token: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub eos_token: Option<String>,
|
||||
pub tokenizer_class: Option<String>,
|
||||
pub add_bos_token: Option<bool>,
|
||||
pub add_eos_token: Option<bool>,
|
||||
}
|
||||
|
||||
impl HubTokenizerConfig {
|
||||
@ -70,6 +73,25 @@ impl HubTokenizerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "processor_class")]
|
||||
pub enum HubPreprocessorConfig {
|
||||
Idefics2Processor(Idefics2Preprocessor),
|
||||
}
|
||||
|
||||
impl HubPreprocessorConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Idefics2Preprocessor {
|
||||
#[serde(default)]
|
||||
do_image_splitting: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubProcessorConfig {
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
|
@ -13,9 +13,11 @@ use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
|
||||
use text_generation_router::{
|
||||
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
|
||||
use tower_http::cors::AllowOrigin;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
@ -214,6 +216,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
preprocessor_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
) = match api {
|
||||
@ -221,6 +224,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
Some(local_path.join("tokenizer.json")),
|
||||
Some(local_path.join("config.json")),
|
||||
Some(local_path.join("tokenizer_config.json")),
|
||||
Some(local_path.join("preprocessor_config.json")),
|
||||
Some(local_path.join("processor_config.json")),
|
||||
None,
|
||||
),
|
||||
@ -237,6 +241,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
};
|
||||
let config_filename = api_repo.get("config.json").await.ok();
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||
|
||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
||||
@ -249,6 +254,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
preprocessor_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
)
|
||||
@ -263,13 +269,12 @@ async fn main() -> Result<(), RouterError> {
|
||||
repo.get("tokenizer.json"),
|
||||
repo.get("config.json"),
|
||||
repo.get("tokenizer_config.json"),
|
||||
repo.get("preprocessor_config.json"),
|
||||
repo.get("processor_config.json"),
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let tokenizer: Option<Tokenizer> =
|
||||
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
@ -300,6 +305,23 @@ async fn main() -> Result<(), RouterError> {
|
||||
HubTokenizerConfig::default()
|
||||
});
|
||||
|
||||
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
|
||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
||||
if let Some(tokenizer) = &mut tokenizer {
|
||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
|
||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||
tokenizer.with_post_processor(post_processor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenizer
|
||||
});
|
||||
|
||||
let preprocessor_config =
|
||||
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
|
||||
let processor_config = processor_config_filename
|
||||
.and_then(HubProcessorConfig::from_file)
|
||||
.unwrap_or_default();
|
||||
@ -361,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
tokenizer_config,
|
||||
preprocessor_config,
|
||||
processor_config,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
@ -504,6 +527,77 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
|
||||
Some(tokenizer_config)
|
||||
}
|
||||
|
||||
/// Create a post_processor for the LlamaTokenizer
|
||||
pub fn create_post_processor(
|
||||
tokenizer: &Tokenizer,
|
||||
tokenizer_config: &HubTokenizerConfig,
|
||||
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
||||
|
||||
let bos_token = tokenizer_config.bos_token.as_ref();
|
||||
let eos_token = tokenizer_config.eos_token.as_ref();
|
||||
|
||||
if add_bos_token && bos_token.is_none() {
|
||||
panic!("add_bos_token = true but bos_token is None");
|
||||
}
|
||||
|
||||
if add_eos_token && eos_token.is_none() {
|
||||
panic!("add_eos_token = true but eos_token is None");
|
||||
}
|
||||
|
||||
let mut single = Vec::new();
|
||||
let mut pair = Vec::new();
|
||||
let mut special_tokens = Vec::new();
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
let bos_token_id = tokenizer
|
||||
.token_to_id(bos)
|
||||
.expect("Should have found the bos token id");
|
||||
special_tokens.push((bos.clone(), bos_token_id));
|
||||
single.push(format!("{}:0", bos));
|
||||
pair.push(format!("{}:0", bos));
|
||||
}
|
||||
}
|
||||
|
||||
single.push("$A:0".to_string());
|
||||
pair.push("$A:0".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(eos)
|
||||
.expect("Should have found the eos token id");
|
||||
special_tokens.push((eos.clone(), eos_token_id));
|
||||
single.push(format!("{}:0", eos));
|
||||
pair.push(format!("{}:0", eos));
|
||||
}
|
||||
}
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
pair.push(format!("{}:1", bos));
|
||||
}
|
||||
}
|
||||
|
||||
pair.push("$B:1".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
pair.push(format!("{}:1", eos));
|
||||
}
|
||||
}
|
||||
|
||||
let post_processor = TemplateProcessing::builder()
|
||||
.try_single(single)?
|
||||
.try_pair(pair)?
|
||||
.special_tokens(special_tokens)
|
||||
.build()?;
|
||||
|
||||
Ok(post_processor)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
@ -513,3 +607,36 @@ enum RouterError {
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_post_processor() {
|
||||
let tokenizer_config = HubTokenizerConfig {
|
||||
add_bos_token: None,
|
||||
add_eos_token: None,
|
||||
bos_token: Some("<s>".to_string()),
|
||||
eos_token: Some("</s>".to_string()),
|
||||
chat_template: None,
|
||||
tokenizer_class: None,
|
||||
completion_template: None,
|
||||
};
|
||||
|
||||
let tokenizer =
|
||||
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
|
||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
||||
|
||||
let expected = TemplateProcessing::builder()
|
||||
.try_single("<s>:0 $A:0 <s>:1")
|
||||
.unwrap()
|
||||
.try_pair("<s>:0 $A:0 $B:1")
|
||||
.unwrap()
|
||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(post_processor, expected);
|
||||
}
|
||||
}
|
||||
|
@ -12,9 +12,9 @@ use crate::kserve::{
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
|
||||
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
|
||||
Usage, Validation,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
|
||||
HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
|
||||
Token, TokenizeResponse, Usage, Validation,
|
||||
};
|
||||
use crate::{
|
||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||
@ -1423,6 +1423,7 @@ pub async fn run(
|
||||
_ngrok_authtoken: Option<String>,
|
||||
_ngrok_edge: Option<String>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
preprocessor_config: Option<HubPreprocessorConfig>,
|
||||
processor_config: HubProcessorConfig,
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
@ -1636,6 +1637,7 @@ pub async fn run(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
config,
|
||||
preprocessor_config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
|
@ -1,13 +1,16 @@
|
||||
/// Payload validation logic
|
||||
use crate::config::Config;
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
||||
use crate::{
|
||||
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
|
||||
};
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use image::{io::Reader as ImageReader, ImageFormat};
|
||||
use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use std::iter;
|
||||
use text_generation_client::{Chunk, Image, InputChunk};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
@ -36,6 +39,7 @@ impl Validation {
|
||||
workers: usize,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
preprocessor_config: Option<HubPreprocessorConfig>,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
@ -53,12 +57,18 @@ impl Validation {
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
let config_clone = config.clone();
|
||||
let preprocessor_config_clone = preprocessor_config.clone();
|
||||
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
|
||||
senders.push(tokenizer_sender);
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
|
||||
tokenizer_worker(
|
||||
tokenizer_clone,
|
||||
config_clone,
|
||||
preprocessor_config_clone,
|
||||
tokenizer_receiver,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
@ -422,13 +432,20 @@ async fn round_robin_task(
|
||||
fn tokenizer_worker(
|
||||
tokenizer: Tokenizer,
|
||||
config: Option<Config>,
|
||||
preprocessor_config: Option<HubPreprocessorConfig>,
|
||||
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
|
||||
) {
|
||||
// Loop over requests
|
||||
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, &config))
|
||||
.send(prepare_input(
|
||||
inputs,
|
||||
truncate,
|
||||
&tokenizer,
|
||||
config.as_ref(),
|
||||
preprocessor_config.as_ref(),
|
||||
))
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
@ -508,16 +525,67 @@ fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), Validatio
|
||||
}
|
||||
}
|
||||
|
||||
fn image_tokens(
|
||||
config: &Config,
|
||||
preprocessor_config: Option<&HubPreprocessorConfig>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
) -> String {
|
||||
use Config::*;
|
||||
use HubPreprocessorConfig::*;
|
||||
match config {
|
||||
Idefics => "<image>".to_string(),
|
||||
Idefics2(config) => {
|
||||
const FAKE: &str = "<fake_token_around_image>";
|
||||
const IMAGE: &str = "<image>";
|
||||
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
|
||||
let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len());
|
||||
image_string.push_str(FAKE);
|
||||
image_string.extend(iter::repeat(IMAGE).take(slots));
|
||||
image_string.push_str(FAKE);
|
||||
|
||||
if matches!(
|
||||
preprocessor_config,
|
||||
Some(Idefics2Processor(Idefics2Preprocessor {
|
||||
do_image_splitting: true,
|
||||
..
|
||||
}))
|
||||
) {
|
||||
image_string = image_string.repeat(5);
|
||||
};
|
||||
|
||||
image_string
|
||||
}
|
||||
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
||||
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||
}
|
||||
}
|
||||
|
||||
fn image_tokens_fixup(config: &Config, text: String) -> String {
|
||||
match config {
|
||||
Config::Idefics2(_) => {
|
||||
const FAKE: &str = "<fake_token_around_image>";
|
||||
text.replace(&format!("{FAKE}{FAKE}"), FAKE)
|
||||
}
|
||||
_ => text,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get input length and optionally truncate it
|
||||
fn prepare_input(
|
||||
inputs: String,
|
||||
_truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
config: &Option<Config>,
|
||||
config: Option<&Config>,
|
||||
preprocessor_config: Option<&HubPreprocessorConfig>,
|
||||
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
|
||||
use Config::*;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
Some(Config::LlavaNext(config)) => {
|
||||
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
@ -529,88 +597,17 @@ fn prepare_input(
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
Some(Config::Paligemma(config)) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
Some(Config::Idefics2(config)) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
tokenizer_query.push_str("<fake_token_around_image>");
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
tokenizer_query.push_str("<fake_token_around_image>");
|
||||
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
Some(Config::Idefics) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (data, mimetype, _height, _width) =
|
||||
fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = 1;
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
tokenizer_query = image_tokens_fixup(config, tokenizer_query);
|
||||
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
|
||||
@ -750,7 +747,7 @@ pub enum ValidationError {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{PaliTextConfig, Paligemma};
|
||||
use crate::config::{Idefics2, PaliTextConfig, Paligemma};
|
||||
use crate::default_parameters;
|
||||
use crate::tests::get_tokenizer;
|
||||
|
||||
@ -769,6 +766,7 @@ mod tests {
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
None,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -803,6 +801,7 @@ mod tests {
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
None,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -836,6 +835,7 @@ mod tests {
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
None,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -874,6 +874,7 @@ mod tests {
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
None,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -941,6 +942,7 @@ mod tests {
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
None,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
@ -1026,6 +1028,7 @@ mod tests {
|
||||
workers,
|
||||
tokenizer,
|
||||
Some(config),
|
||||
None,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
@ -1058,4 +1061,83 @@ mod tests {
|
||||
"Failed to process images",
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_idefics2_correct_n_fake_tokens() {
|
||||
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
|
||||
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let disable_grammar_support = true;
|
||||
let workers = 1;
|
||||
let config = Config::Idefics2(Idefics2 {});
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
Some(config),
|
||||
Some(HubPreprocessorConfig::Idefics2Processor(
|
||||
Idefics2Preprocessor {
|
||||
do_image_splitting: true,
|
||||
},
|
||||
)),
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
disable_grammar_support,
|
||||
);
|
||||
|
||||
let (encoding, chunks) = match validation
|
||||
.tokenize(
|
||||
format!(
|
||||
"test",
|
||||
PIXEL_GIF, PIXEL_GIF
|
||||
),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some((encoding, chunks))) => (encoding, chunks),
|
||||
_ => panic!("Unexpected tokenization failure"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
chunks
|
||||
== vec![
|
||||
Chunk::Text("test".to_string()).into(),
|
||||
Chunk::Image(Image {
|
||||
data: pixel_data.clone(),
|
||||
mimetype: "image/gif".to_string()
|
||||
})
|
||||
.into(),
|
||||
Chunk::Image(Image {
|
||||
data: pixel_data.clone(),
|
||||
mimetype: "image/gif".to_string()
|
||||
})
|
||||
.into()
|
||||
],
|
||||
"Failed to process images",
|
||||
);
|
||||
|
||||
// Verify the number of fake tokens:
|
||||
//
|
||||
// - Two images surrounded/separated by a fake token = 3.
|
||||
// - Both are split in 5 subimages, separated by a fake token: 2 * 4
|
||||
//
|
||||
// Fake tokens get split up by the testing tokenizer, but we don't care.
|
||||
assert_eq!(
|
||||
encoding
|
||||
.get_tokens()
|
||||
.iter()
|
||||
.filter(|t| *t == "fake")
|
||||
.count(),
|
||||
11
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
sym: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQWeight:
|
||||
qweight: torch.Tensor
|
||||
|
@ -166,12 +166,17 @@ def get_linear(weight, bias, quantize):
|
||||
|
||||
elif quantize == "gptq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
if not isinstance(weight, GPTQWeight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
)
|
||||
|
||||
if isinstance(weight, GPTQMarlinWeight):
|
||||
linear = GPTQMarlinLinear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQWeight):
|
||||
if weight.use_exllama:
|
||||
try:
|
||||
from text_generation_server.layers.gptq import (
|
||||
@ -195,6 +200,11 @@ def get_linear(weight, bias, quantize):
|
||||
weight.bits,
|
||||
weight.groupsize,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
elif quantize == "awq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Linear,
|
||||
GPTQMarlin24Weight,
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
MarlinLinear,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
if isinstance(weight, GPTQMarlinWeight):
|
||||
linear = GPTQMarlinLinear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQMarlin24Weight):
|
||||
if isinstance(weight, GPTQMarlin24Weight):
|
||||
linear = GPTQMarlin24Linear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
|
@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
try:
|
||||
@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
MARLIN_TILE_SIZE = 16
|
||||
|
||||
|
||||
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and marlin_kernels is not None
|
||||
and has_sm_8_0
|
||||
and quantize == "gptq"
|
||||
and gptq_params.quant_method == "gptq"
|
||||
and gptq_params.bits in GPTQ_MARLIN_BITS
|
||||
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
and gptq_params.sym
|
||||
)
|
||||
|
||||
|
||||
def _check_marlin_kernels():
|
||||
if not (SYSTEM == "cuda" and has_sm_8_0):
|
||||
raise NotImplementedError(
|
||||
|
@ -68,6 +68,9 @@ try:
|
||||
from text_generation_server.models.flash_gemma import (
|
||||
FlashGemma,
|
||||
)
|
||||
from text_generation_server.models.flash_gemma2 import (
|
||||
FlashGemma2,
|
||||
)
|
||||
from text_generation_server.models.pali_gemma import (
|
||||
PaliGemma,
|
||||
)
|
||||
@ -102,6 +105,7 @@ if FLASH_ATTENTION:
|
||||
__all__.append(FlashQwen2)
|
||||
__all__.append(FlashStarcoder2)
|
||||
__all__.append(FlashGemma)
|
||||
__all__.append(FlashGemma2)
|
||||
__all__.append(FlashCohere)
|
||||
|
||||
MAMBA_AVAILABLE = True
|
||||
@ -145,6 +149,11 @@ class ModelType(enum.Enum):
|
||||
"name": "Gemma",
|
||||
"url": "https://huggingface.co/google/gemma-7b",
|
||||
}
|
||||
GEMMA2 = {
|
||||
"type": "gemma2",
|
||||
"name": "Gemma2",
|
||||
"url": "https://huggingface.co/google/gemma2-9b",
|
||||
}
|
||||
COHERE = {
|
||||
"type": "cohere",
|
||||
"name": "Cohere",
|
||||
@ -637,6 +646,27 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == GEMMA2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGemma2(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||
else:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == COHERE:
|
||||
if FLASH_ATTENTION:
|
||||
|
@ -0,0 +1,500 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
|
||||
|
||||
class Gemma2Config(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256128,
|
||||
hidden_size=3072,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class Gemma2FastRMSNorm(FastRMSNorm):
|
||||
@classmethod
|
||||
def load(cls, prefix, weights, eps=1e-6):
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
weights.dtype = dtype
|
||||
new = cls(weight, eps)
|
||||
new.dtype = dtype
|
||||
return new
|
||||
|
||||
# perform the multiplication in full precision and downcast after
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states.to(self.dtype), residual
|
||||
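Gemma checkpoints store the RMSNorm scale offset by one, which is why the loader above adds 1 to the raw tensor and keeps the weight in float32; the forward pass then normalizes in float32 and only downcasts at the end. A standalone reference of the same arithmetic in plain PyTorch (independent of the FastRMSNorm base class):

import torch

def gemma2_rmsnorm_reference(x: torch.Tensor, raw_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # raw_weight is the checkpoint tensor; the effective scale is raw_weight + 1
    weight = raw_weight.float() + 1
    x32 = x.float()
    variance = x32.pow(2).mean(-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + eps)
    return (x32 * weight).to(x.dtype)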
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.head_dim
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashGemma2Attention(torch.nn.Module):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = config.head_dim
|
||||
self.causal = causal
|
||||
if is_sliding:
|
||||
self.window_size = config.sliding_window
|
||||
else:
|
||||
self.window_size = -1
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
# self.softmax_scale = self.head_size**-0.5
|
||||
self.softmax_scale = config.query_pre_attn_scalar**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
causal=self.causal,
|
||||
window_size_left=self.window_size,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
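Two Gemma 2 specifics are visible in the attention module above: the softmax scale comes from config.query_pre_attn_scalar rather than the head size (the commented-out line keeps the old formula), and window_size is only set on the sliding layers, with -1 meaning unrestricted attention. A tiny numeric sketch of the scale difference, using an assumed query_pre_attn_scalar value:

head_size = 256
query_pre_attn_scalar = 224                 # assumed value; read from the HF config in practice
old_scale = head_size ** -0.5               # ~0.0625
new_scale = query_pre_attn_scalar ** -0.5   # ~0.0668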
|
||||
|
||||
class Gemma2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
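Gemma2MLP fuses gate_proj and up_proj into a single column-parallel matmul and splits the result before applying the gated GELU. A small sketch of the same split-and-gate pattern with plain nn.Linear layers and tiny illustrative sizes:

import torch
import torch.nn as nn

hidden, intermediate = 8, 16                                 # toy sizes, not the config defaults
gate_up = nn.Linear(hidden, 2 * intermediate, bias=False)    # fused gate/up weight
down = nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(4, hidden)
g_u = gate_up(x).view(-1, 2, intermediate)                   # [:, 0] = gate, [:, 1] = up
act = nn.functional.gelu(g_u[:, 0], approximate="tanh")      # gelu_pytorch_tanh
y = down(act * g_u[:, 1])                                    # same shape as x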
|
||||
|
||||
class FlashGemma2Layer(nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGemma2Attention(
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
causal=causal,
|
||||
is_sliding=is_sliding,
|
||||
)
|
||||
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = Gemma2FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = Gemma2FastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
|
||||
prefix=f"{prefix}.pre_feedforward_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_feedforward_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
|
||||
normed_attn_res_output = normed_attn_res_output + res
|
||||
res = normed_attn_res_output
|
||||
|
||||
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
||||
mlp_output = self.mlp(pre_normed)
|
||||
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
||||
|
||||
return post_hidden_states, normed_attn_res_output
|
||||
|
||||
|
||||
class FlashGemma2Model(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGemma2Layer(
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
causal=causal,
|
||||
is_sliding=layer_id % 2 == 0,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = Gemma2FastRMSNorm.load(
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
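One detail specific to Gemma 2 in the model above is is_sliding=layer_id % 2 == 0: even-numbered layers use the config's sliding window while odd layers attend globally (window_size = -1 in FlashGemma2Attention). A quick sketch of how that alternation plays out, with an assumed window of 4096:

def layer_windows(num_hidden_layers: int, sliding_window: int = 4096) -> list:
    # mirrors `is_sliding = layer_id % 2 == 0` above; -1 means no window
    return [
        sliding_window if layer_id % 2 == 0 else -1
        for layer_id in range(num_hidden_layers)
    ]

print(layer_windows(6))   # [4096, -1, 4096, -1, 4096, -1]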
|
||||
|
||||
class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
embed_norm = config.hidden_size**0.5
|
||||
if not prefix:
|
||||
prefix = "model"
|
||||
else:
|
||||
prefix = f"{prefix}.model"
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_tokens.weight *= embed_norm
|
||||
|
||||
self.model = FlashGemma2Model(
|
||||
prefix=prefix, config=config, weights=weights, causal=causal
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
prefix=(
|
||||
f"{prefix}.embed_tokens"
|
||||
if config.tie_word_embeddings
|
||||
else f"{prefix}.lm_head"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
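FlashGemma2ForCausalLM scales the embedding matrix in place by sqrt(hidden_size) (the embed_norm factor) so the multiplication is paid once at load time instead of on every forward pass, and it reuses embed_tokens as the LM head when tie_word_embeddings is set. A toy sketch of the scaling:

import torch

hidden_size = 3072                      # Gemma2Config default above
embed_norm = hidden_size ** 0.5
weight = torch.randn(16, hidden_size)   # toy 16-token vocabulary
weight *= embed_norm                    # done once at load, as in __init__ above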
@ -375,8 +375,6 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
@ -39,7 +39,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
@ -47,7 +47,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
tuple: The shape of the image patch grid in the format (height, width).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
@ -230,7 +230,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
|
@ -28,8 +28,12 @@ from text_generation_server.models.types import (
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
import text_generation_server.models.globals as tgi_globals
|
||||
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
|
||||
from text_generation_server.models.globals import (
|
||||
MEM_POOL,
|
||||
CUDA_GRAPHS,
|
||||
get_adapter_to_index,
|
||||
MODEL_ID,
|
||||
)
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
@ -233,7 +237,8 @@ class FlashCausalLMBatch(Batch):
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||
ADAPTER_TO_INDEX = get_adapter_to_index()
|
||||
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||
adapter_indices_list.append(torch.full((input_length,), adapter_index))
|
||||
adapter_set.add(adapter_index)
|
||||
|
||||
@ -499,9 +504,8 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
|
||||
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
|
||||
self.requests[idx].adapter_id, 0
|
||||
)
|
||||
ADAPTER_TO_INDEX = get_adapter_to_index()
|
||||
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
|
||||
adapter_set.add(adapter_index)
|
||||
|
||||
remaining_tokens = (
|
||||
|
75
server/text_generation_server/models/flash_gemma2.py
Normal file
@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
from transformers import PretrainedConfig, AutoTokenizer
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashGemma2(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma2 is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = PretrainedConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
# TODO hardcoded
|
||||
prefix = ""
|
||||
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGemma2, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
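flash_gemma2.py follows the usual flash-model recipe: build the tokenizer and config, shard the safetensors across the process group, construct FlashGemma2ForCausalLM, then hand everything to FlashCausalLM. A hypothetical instantiation (the model id and dtype are assumptions, and a CUDA device plus downloaded weights are required):

import torch
from text_generation_server.models.flash_gemma2 import FlashGemma2

model = FlashGemma2(
    "google/gemma-2-9b-it",    # assumed model id
    quantize=None,
    dtype=torch.bfloat16,
)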
@ -34,3 +34,8 @@ ADAPTER_TO_INDEX: Dict[str, int] = None
|
||||
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
|
||||
global ADAPTER_TO_INDEX
|
||||
ADAPTER_TO_INDEX = adapter_to_index
|
||||
|
||||
|
||||
def get_adapter_to_index():
|
||||
global ADAPTER_TO_INDEX
|
||||
return ADAPTER_TO_INDEX
|
||||
|
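globals.py now exposes get_adapter_to_index() next to set_adapter_to_index(), and flash_causal_lm.py calls the getter instead of reading tgi_globals.ADAPTER_TO_INDEX directly, so call sites always see the mapping that was installed after import. A minimal usage sketch:

from text_generation_server.models.globals import (
    set_adapter_to_index,
    get_adapter_to_index,
)

set_adapter_to_index({"my-lora": 1})                       # done once at startup
adapter_index = get_adapter_to_index().get("my-lora", 0)   # 1; unknown ids fall back to 0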
@ -39,7 +39,9 @@ class PaliGemmaBatch(VlmCausalLMBatch):
|
||||
# TODO do_convert_RGB should be on by default ?
|
||||
image = image.convert("RGB")
|
||||
image_input = processor.image_processor(image, return_tensors="pt")
|
||||
full_text += image_text_replacement(image_input, config, image_id)
|
||||
full_text += image_text_replacement(
|
||||
processor, image_input, config, image_id
|
||||
)
|
||||
image_inputs.append(image_input)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
|
@ -1,3 +1,4 @@
|
||||
from itertools import repeat
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
@ -15,6 +16,9 @@ from text_generation_server.models.flash_mistral import (
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS2_IMAGE_TOKEN = "<image>"
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
@ -22,7 +26,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
@ -39,15 +43,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_text_replacement(image_input, config, image_id) -> str:
|
||||
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
|
||||
if config.model_type == "idefics2":
|
||||
# TODO technically depends on image splitting which is not implemented.
|
||||
num_features = 320
|
||||
return (
|
||||
"<fake_token_around_image>"
|
||||
+ "<image>" * num_features
|
||||
+ "<fake_token_around_image>"
|
||||
)
|
||||
image_seq_len = 64
|
||||
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
|
||||
if processor.image_processor.do_image_splitting:
|
||||
image_str *= 5
|
||||
return image_str
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input["image_sizes"][image_id]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
@ -64,20 +66,35 @@ def image_text_replacement(image_input, config, image_id) -> str:
|
||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||
|
||||
|
||||
def image_text_replacement_fixup(config, text: str) -> str:
|
||||
if config.model_type == "idefics2":
|
||||
return text.replace(
|
||||
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
|
||||
)
|
||||
return text
|
||||
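For idefics2 the placeholder is now the fake-image token, 64 <image> tokens, then the fake token again, repeated five times when the processor performs image splitting; image_text_replacement_fixup then collapses the doubled fake token produced by two back-to-back images. A sketch of both steps:

FAKE = "<fake_token_around_image>"
IMG = "<image>"

def idefics2_placeholder(do_image_splitting: bool, image_seq_len: int = 64) -> str:
    s = f"{FAKE}{IMG * image_seq_len}{FAKE}"
    return s * 5 if do_image_splitting else s

def fixup(text: str) -> str:
    # adjacent images yield "...<fake><fake>..."; keep a single separator
    return text.replace(f"{FAKE}{FAKE}", FAKE)

two_images = idefics2_placeholder(False) + idefics2_placeholder(False)
assert FAKE * 2 not in fixup(two_images)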
|
||||
|
||||
def get_unpadded_features(
|
||||
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
aspect_ratio: float = width / height
|
||||
aspect_ratio: float = original_width / original_height
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
current_height = new_height
|
||||
new_height = (original_height * current_width) // original_width
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height = current_height - (2 * padding)
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
current_width = new_width
|
||||
new_width = (original_width * current_height) // original_height
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width = current_width - (2 * padding)
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
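Renaming the arguments to original_height/original_width makes the intent of get_unpadded_features clearer: it strips the rows or columns of pure padding that letterboxing adds when the original aspect ratio differs from the square patch grid. A worked example with assumed llava-next style values (24 patches per side, a 2x2 grid, an 800x600 image):

# Worked example of the unpadding arithmetic above (all values are assumptions).
original_height, original_width = 600, 800
npatches, num_patch_height, num_patch_width = 24, 2, 2

current_height = npatches * num_patch_height                        # 48
current_width = npatches * num_patch_width                          # 48
# the image (4:3) is wider than the grid (1:1), so rows get unpadded
new_height = (original_height * current_width) // original_width    # 36
padding = (current_height - new_height) // 2                        # 6
current_height -= 2 * padding                                       # 36
unpadded_features = current_height * current_width                  # 1728
newline_features = current_height                                   # 36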
@ -96,7 +113,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
||||
|
||||
npatches = image_size // patch_size
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
[height, width],
|
||||
image_grid_pinpoints,
|
||||
image_size,
|
||||
@ -168,9 +187,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
if chunk_type == "text":
|
||||
full_text += chunk.text
|
||||
elif chunk_type == "image":
|
||||
full_text += image_text_replacement(image_inputs, config, image_id)
|
||||
full_text += image_text_replacement(
|
||||
processor, image_inputs, config, image_id
|
||||
)
|
||||
image_id += 1
|
||||
|
||||
full_text = image_text_replacement_fixup(config, full_text)
|
||||
|
||||
batch_inputs.append(full_text)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
|
@ -1,25 +1,15 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
from loguru import logger
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
sym: bool
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
self,
|
||||
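The module-local _GPTQParams dataclass is gone; the same container now lives in text_generation_server.layers.gptq and is shared with the Marlin helpers. Based on the fields of the removed definition, the shared dataclass presumably has this shape:

# Presumed shape of text_generation_server.layers.gptq.GPTQParams, inferred
# from the _GPTQParams fields removed above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class GPTQParams:
    bits: int
    checkpoint_format: Optional[str]
    groupsize: int
    desc_act: bool
    quant_method: str
    sym: bool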
@ -212,6 +202,10 @@ class Weights:
|
||||
"""
|
||||
if quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
@ -221,17 +215,28 @@ class Weights:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
qzeros = self.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = self.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = self.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||
@ -269,7 +274,6 @@ class Weights:
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
B = self.get_packed_sharded(
|
||||
@ -286,31 +290,6 @@ class Weights:
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
gptq_params = self._get_gptq_params()
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = self.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
weight = repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
else:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
@ -356,6 +335,10 @@ class Weights:
|
||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
elif quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
@ -366,14 +349,31 @@ class Weights:
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
@ -425,10 +425,8 @@ class Weights:
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
@ -452,36 +450,6 @@ class Weights:
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
gptq_params = self._get_gptq_params()
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
|
||||
dim=1,
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
weight = repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
@ -544,9 +512,41 @@ class Weights:
|
||||
)
|
||||
|
||||
elif quantize == "gptq":
|
||||
use_exllama = True
|
||||
gptq_params = self._get_gptq_params()
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = self.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if gptq_params.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
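The row-sharded GPTQ path now uses the same gate as the packed and column paths: if can_use_gptq_marlin accepts the checkpoint, the weight is repacked for the Marlin kernels (with sharded_infeatures set when the input dimension is split across ranks); otherwise the existing exllama route is kept, and only for 4-bit weights. A rough sketch of that ordering, using the names from the diff with the non-Marlin branch elided:

def load_row_gptq_sketch(weights, prefix: str, quantize: str, gptq_params):
    from text_generation_server.layers.marlin import (
        can_use_gptq_marlin,
        repack_gptq_for_marlin,
    )

    if can_use_gptq_marlin(gptq_params, quantize):
        qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
        if gptq_params.desc_act or gptq_params.groupsize == -1:
            scales = weights.get_tensor(f"{prefix}.scales")
        else:
            scales = weights.get_sharded(f"{prefix}.scales", dim=0)
        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            g_idx=g_idx,
            bits=gptq_params.bits,
            desc_act=gptq_params.desc_act,
            groupsize=gptq_params.groupsize,
            sym=gptq_params.sym,
            sharded_infeatures=weights.process_group.size() > 1,
        )
    # non-Marlin route (elided here): exllama kernels stay enabled only when
    # gptq_params.bits == 4, matching the `use_exllama` logic above
    raise NotImplementedError("exllama/packed fallback elided in this sketch")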
@ -672,10 +672,8 @@ class Weights:
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
@ -698,35 +696,6 @@ class Weights:
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = self.process_group.size() > 1
|
||||
|
||||
weight = repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||
@ -743,18 +712,17 @@ class Weights:
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def _get_gptq_params(self) -> _GPTQParams:
|
||||
def _get_gptq_params(self) -> GPTQParams:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = False
|
||||
sym = True
|
||||
sym = False
|
||||
quant_method = "gptq"
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
@ -767,7 +735,7 @@ class Weights:
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return _GPTQParams(
|
||||
return GPTQParams(
|
||||
bits=bits,
|
||||
checkpoint_format=checkpoint_format,
|
||||
desc_act=desc_act,
|
||||
|