Merge branch 'main' into ci_amd3

fxmarty 2024-07-01 14:14:46 +02:00
commit 59849777de
34 changed files with 7255 additions and 6516 deletions

View File

@@ -225,7 +225,7 @@ jobs:
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
-PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
+PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4

Cargo.lock (generated) — 599 lines changed
File diff suppressed because it is too large

View File

@@ -9,7 +9,7 @@ members = [
resolver = "2"
[workspace.package]
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@@ -75,9 +75,11 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
```shell
model=HuggingFaceH4/zephyr-7b-beta
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+# share a volume with the Docker container to avoid downloading weights every run
+volume=$PWD/data
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
+ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $model
```
And then you can make requests like
@@ -91,7 +93,7 @@ curl 127.0.0.1:8080/generate_stream \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.0-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```

View File

@@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \
+ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \
--model-id $model
```

View File

@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:2.0.4 \
+ghcr.io/huggingface/text-generation-inference:2.1.0 \
--model-id $model
```

View File

@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:2.0.4 \
+ghcr.io/huggingface/text-generation-inference:2.1.0 \
--model-id $model
```
@@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
-docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help
+docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help
```
</Tip>

View File

@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
+- [Gemma2](https://huggingface.co/google/gemma2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)

View File

@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6230469,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.076660156,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10821533,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}

View File

@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -2.2539062,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15563965,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0,
"special": false,
"text": " has"
},
{
"id": 539,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 3686,
"logprob": 0.0,
"special": false,
"text": " yet"
},
{
"id": 3288,
"logprob": 0.0,
"special": false,
"text": " sent"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
}
],
"top_tokens": null
},
"generated_text": "Test request. The server has not yet sent any data.\n\n"
}

View File

@@ -1,338 +0,0 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}
]

View File

@@ -8,61 +8,61 @@
"tokens": [
{
"id": 330,
-"logprob": -0.13000488,
+"logprob": -0.08660889,
"special": false,
"text": " A"
},
{
"id": 13088,
-"logprob": -0.6713867,
+"logprob": -0.7089844,
"special": false,
"text": " chicken"
},
{
"id": 349,
-"logprob": -0.2980957,
+"logprob": -0.32885742,
"special": false,
"text": " is"
},
{
"id": 6398,
-"logprob": -0.060638428,
+"logprob": -0.05126953,
"special": false,
"text": " sitting"
},
{
"id": 356,
-"logprob": -0.27319336,
+"logprob": -0.35229492,
"special": false,
"text": " on"
},
{
"id": 264,
-"logprob": -0.140625,
+"logprob": -0.12561035,
"special": false,
"text": " a"
},
{
"id": 17972,
-"logprob": -0.040405273,
+"logprob": -0.038085938,
"special": false,
"text": " pile"
},
{
"id": 302,
-"logprob": -0.0002708435,
+"logprob": -0.00018656254,
"special": false,
"text": " of"
},
{
"id": 2445,
-"logprob": -0.095336914,
+"logprob": -0.07293701,
"special": false,
"text": " money"
},
{
"id": 28723,
-"logprob": -0.0068359375,
+"logprob": -0.004852295,
"special": false,
"text": "."
}

View File

@@ -8,115 +8,115 @@
"tokens": [
{
"id": 415,
-"logprob": -0.04421997,
+"logprob": -0.039886475,
"special": false,
"text": " The"
},
{
"id": 12072,
-"logprob": -0.13500977,
+"logprob": -0.1430664,
"special": false,
"text": " cow"
},
{
"id": 349,
-"logprob": -0.06750488,
+"logprob": -0.056488037,
"special": false,
"text": " is"
},
{
"id": 6328,
-"logprob": -0.6352539,
+"logprob": -0.6855469,
"special": false,
"text": " standing"
},
{
"id": 356,
-"logprob": -0.16186523,
+"logprob": -0.1685791,
"special": false,
"text": " on"
},
{
"id": 272,
-"logprob": -0.5078125,
+"logprob": -0.50097656,
"special": false,
"text": " the"
},
{
"id": 10305,
-"logprob": -0.017913818,
+"logprob": -0.017303467,
"special": false,
"text": " beach"
},
{
"id": 304,
-"logprob": -1.5205078,
+"logprob": -1.3564453,
"special": false,
"text": " and"
},
{
"id": 272,
-"logprob": -0.029174805,
+"logprob": -0.017868042,
"special": false,
"text": " the"
},
{
"id": 13088,
-"logprob": -0.003479004,
+"logprob": -0.0027103424,
"special": false,
"text": " chicken"
},
{
"id": 349,
-"logprob": -0.0035095215,
+"logprob": -0.003156662,
"special": false,
"text": " is"
},
{
"id": 6398,
-"logprob": -0.3088379,
+"logprob": -0.37304688,
"special": false,
"text": " sitting"
},
{
"id": 356,
-"logprob": -0.027755737,
+"logprob": -0.034576416,
"special": false,
"text": " on"
},
{
"id": 264,
-"logprob": -0.31884766,
+"logprob": -0.29418945,
"special": false,
"text": " a"
},
{
"id": 17972,
-"logprob": -0.047943115,
+"logprob": -0.042877197,
"special": false,
"text": " pile"
},
{
"id": 302,
-"logprob": -0.0002925396,
+"logprob": -0.00028443336,
"special": false,
"text": " of"
},
{
"id": 2445,
-"logprob": -0.02935791,
+"logprob": -0.023223877,
"special": false,
"text": " money"
},
{
"id": 28723,
-"logprob": -0.031219482,
+"logprob": -0.018157959,
"special": false,
"text": "."
},
{
"id": 32002,
-"logprob": -0.00034475327,
+"logprob": -0.00018393993,
"special": true,
"text": "<end_of_utterance>"
},

View File

@@ -898,13 +898,20 @@ enum LauncherError {
WebserverCannotStart,
}
-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(
+model_id: &str,
+revision: Option<&str>,
+trust_remote_code: bool,
+huggingface_hub_cache: Option<&str>,
+weights_cache_override: Option<&str>,
+running: Arc<AtomicBool>,
+) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
let mut download_args = vec![
"download-weights".to_string(),
-args.model_id.to_string(),
+model_id.to_string(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
@@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
];
// Model optional revision
-if let Some(revision) = &args.revision {
+if let Some(revision) = &revision {
download_args.push("--revision".to_string());
download_args.push(revision.to_string())
}
// Trust remote code for automatic peft fusion
-if args.trust_remote_code {
+if trust_remote_code {
download_args.push("--trust-remote-code".to_string());
}
@@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
-if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
+if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
@@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
-if let Some(weights_cache_override) = &args.weights_cache_override {
+if let Some(weights_cache_override) = &weights_cache_override {
envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
@@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
};
// Start process
-tracing::info!("Starting download process.");
+tracing::info!("Starting check and download process for {model_id}");
let mut download_process = match Command::new("text-generation-server")
.args(download_args)
.env_clear()
@@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
loop {
if let Some(status) = download_process.try_wait().unwrap() {
if status.success() {
-tracing::info!("Successfully downloaded weights.");
+tracing::info!("Successfully downloaded weights for {model_id}");
break;
}
@@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
.expect("Error setting Ctrl-C handler");
// Download and convert model weights
-download_convert_model(&args, running.clone())?;
+download_convert_model(
&args.model_id,
args.revision.as_deref(),
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
download_convert_model(
adapter,
None,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
}
}
if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop

View File

@@ -22,9 +22,10 @@ text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { workspace = true }
+itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
-metrics-exporter-prometheus = { version = "0.12.1", features = [] }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
@@ -37,9 +38,9 @@ tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
-tracing = "0.1.37"
+tracing = "0.1.40"
tracing-opentelemetry = "0.21.0"
-tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
+tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }

View File

@@ -71,10 +71,12 @@ fn get_unpadded_features(
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
-(new_height, current_width)
+let padding = (current_height - new_height) / 2;
+(current_height - (2 * padding), current_width)
} else {
let new_width = (width * current_height) / height;
-(current_height, new_width)
+let padding = (current_width - new_width) / 2;
+(current_height, current_width - (2 * padding))
};
let unpadded_features = current_height * current_width;
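The padding-based computation above is not a pure refactor: with integer arithmetic it can differ from the old result by one row or column whenever the leftover padding is odd. A self-contained sketch, using invented example values that are not taken from this PR:

```rust
fn main() {
    // Invented example: a 640x480 image mapped onto a 100x100 feature grid.
    let (width, height) = (640usize, 480usize);
    let (current_width, current_height) = (100usize, 100usize);

    // Old behaviour: use the rescaled height directly.
    let new_height = (height * current_width) / width; // 75

    // New behaviour: derive the unpadded size from the padding
    // (presumably to match the reference preprocessing).
    let padding = (current_height - new_height) / 2; // 25 / 2 = 12 (integer division)
    let unpadded_height = current_height - 2 * padding; // 100 - 24 = 76

    assert_eq!(new_height, 75);
    assert_eq!(unpadded_height, 76); // off by one whenever the total padding is odd
}
```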
@@ -88,7 +90,9 @@ impl LlavaNext {
let patch_size = self.vision_config.patch_size;
assert!(image_size % patch_size == 0);
let npatches = image_size / patch_size;
-let (num_patch_height, num_patch_width) =
+// Dimensions are intentionally swapped to be bug-compatible with
+// upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+let (num_patch_width, num_patch_height) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
let (unpadded_features, newline_features) =
@@ -112,7 +116,7 @@ pub struct Idefics2 {}
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
-320
+64
}
}
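The drop from 320 to 64 reads as follows: 64 is the per-image token count, and the old constant baked in the 5x factor that image splitting produces. That factor now lives in the router's `image_tokens` handling later in this diff, gated on the preprocessor config, instead of in `get_number_of_features`. A rough sanity check, assuming the 5x splitting factor used by the new `image_tokens` code and its tests:

```rust
fn main() {
    // Assumed relationship between the old and new constants (see the repeat(5)
    // in the new image_tokens() code and the 5-subimage comment in the tests).
    let tokens_per_image_slot = 64;
    let split_factor = 5; // full image plus crops when do_image_splitting is enabled
    assert_eq!(tokens_per_image_slot * split_factor, 320); // the old hard-coded value
}
```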
@@ -158,6 +162,7 @@ pub enum Config {
Baichuan,
Paligemma(Paligemma),
Gemma,
+Gemma2,
Cohere,
Drbx,
Falcon,

View File

@@ -61,6 +61,9 @@ pub struct HubTokenizerConfig {
pub bos_token: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub eos_token: Option<String>,
+pub tokenizer_class: Option<String>,
+pub add_bos_token: Option<bool>,
+pub add_eos_token: Option<bool>,
}
impl HubTokenizerConfig {
@@ -70,6 +73,25 @@ impl HubTokenizerConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor),
}
impl HubPreprocessorConfig {
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).ok()?;
serde_json::from_str(&content).ok()
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Idefics2Preprocessor {
#[serde(default)]
do_image_splitting: bool,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,
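The new `HubPreprocessorConfig` enum above is internally tagged, so the variant is selected from the `processor_class` field of `preprocessor_config.json`. A minimal, self-contained sketch of that behaviour, with the types re-declared locally and the JSON contents invented for illustration:

```rust
use serde::Deserialize;

// Re-declared locally so the sketch compiles on its own; mirrors the enum above.
#[derive(Debug, Deserialize)]
#[serde(tag = "processor_class")]
enum HubPreprocessorConfig {
    Idefics2Processor(Idefics2Preprocessor),
}

#[derive(Debug, Deserialize)]
struct Idefics2Preprocessor {
    #[serde(default)]
    do_image_splitting: bool,
}

fn main() {
    // Hypothetical preprocessor_config.json contents, trimmed to the relevant fields.
    let raw = r#"{"processor_class": "Idefics2Processor", "do_image_splitting": true}"#;
    let config: HubPreprocessorConfig = serde_json::from_str(raw).unwrap();
    match config {
        HubPreprocessorConfig::Idefics2Processor(p) => assert!(p.do_image_splitting),
    }
}
```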

View File

@@ -13,9 +13,11 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
-use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
+use text_generation_router::{
+server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
+};
use thiserror::Error;
-use tokenizers::Tokenizer;
+use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
@@ -214,6 +216,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_filename,
config_filename,
tokenizer_config_filename,
+preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
@@ -221,6 +224,7 @@ async fn main() -> Result<(), RouterError> {
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
+Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
@@ -237,6 +241,7 @@ async fn main() -> Result<(), RouterError> {
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
@@ -249,6 +254,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_filename,
config_filename,
tokenizer_config_filename,
+preprocessor_config_filename,
processor_config_filename,
model_info,
)
@@ -263,13 +269,12 @@ async fn main() -> Result<(), RouterError> {
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
+repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
-let tokenizer: Option<Tokenizer> =
-tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
@@ -300,6 +305,23 @@ async fn main() -> Result<(), RouterError> {
HubTokenizerConfig::default()
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
+let preprocessor_config =
+preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
@@ -361,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
+preprocessor_config,
processor_config,
messages_api_enabled,
disable_grammar_support,
@@ -504,6 +527,77 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
Some(tokenizer_config)
}
/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos)
.expect("Should have found the bos token id");
special_tokens.push((bos.clone(), bos_token_id));
single.push(format!("{}:0", bos));
pair.push(format!("{}:0", bos));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos)
.expect("Should have found the eos token id");
special_tokens.push((eos.clone(), eos_token_id));
single.push(format!("{}:0", eos));
pair.push(format!("{}:0", eos));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
@@ -513,3 +607,36 @@ enum RouterError {
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig {
add_bos_token: None,
add_eos_token: None,
bos_token: Some("<s>".to_string()),
eos_token: Some("</s>".to_string()),
chat_template: None,
tokenizer_class: None,
completion_template: None,
};
let tokenizer =
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0 <s>:1")
.unwrap()
.try_pair("<s>:0 $A:0 $B:1")
.unwrap()
.special_tokens(vec![("<s>".to_string(), 1)])
.build()
.unwrap();
assert_eq!(post_processor, expected);
}
}

View File

@@ -12,9 +12,9 @@ use crate::kserve::{
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
-Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
-Usage, Validation,
+GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
+HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
+Token, TokenizeResponse, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -1423,6 +1423,7 @@ pub async fn run(
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
+preprocessor_config: Option<HubPreprocessorConfig>,
processor_config: HubProcessorConfig,
messages_api_enabled: bool,
grammar_support: bool,
@@ -1636,6 +1637,7 @@ pub async fn run(
validation_workers,
tokenizer,
config,
+preprocessor_config,
max_best_of,
max_stop_sequences,
max_top_n_tokens,

View File

@@ -1,13 +1,16 @@
/// Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use crate::{
+GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
+use std::iter;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@@ -36,6 +39,7 @@ impl Validation {
workers: usize,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
+preprocessor_config: Option<HubPreprocessorConfig>,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
@@ -53,12 +57,18 @@ impl Validation {
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let config_clone = config.clone();
+let preprocessor_config_clone = preprocessor_config.clone();
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
senders.push(tokenizer_sender);
// Spawn worker
tokio::task::spawn_blocking(move || {
-tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
+tokenizer_worker(
+tokenizer_clone,
+config_clone,
+preprocessor_config_clone,
+tokenizer_receiver,
+)
});
}
@@ -422,13 +432,20 @@ async fn round_robin_task(
fn tokenizer_worker(
tokenizer: Tokenizer,
config: Option<Config>,
+preprocessor_config: Option<HubPreprocessorConfig>,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
response_tx
-.send(prepare_input(inputs, truncate, &tokenizer, &config))
+.send(prepare_input(
+inputs,
+truncate,
+&tokenizer,
+config.as_ref(),
+preprocessor_config.as_ref(),
+))
.unwrap_or(())
})
}
@@ -508,16 +525,67 @@ fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), Validatio
}
}
fn image_tokens(
config: &Config,
preprocessor_config: Option<&HubPreprocessorConfig>,
height: usize,
width: usize,
) -> String {
use Config::*;
use HubPreprocessorConfig::*;
match config {
Idefics => "<image>".to_string(),
Idefics2(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
let slots = config.get_number_of_features(height, width);
let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len());
image_string.push_str(FAKE);
image_string.extend(iter::repeat(IMAGE).take(slots));
image_string.push_str(FAKE);
if matches!(
preprocessor_config,
Some(Idefics2Processor(Idefics2Preprocessor {
do_image_splitting: true,
..
}))
) {
image_string = image_string.repeat(5);
};
image_string
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
_ => unimplemented!("Images tokens are not supported for this model configuration"),
}
}
fn image_tokens_fixup(config: &Config, text: String) -> String {
match config {
Config::Idefics2(_) => {
const FAKE: &str = "<fake_token_around_image>";
text.replace(&format!("{FAKE}{FAKE}"), FAKE)
}
_ => text,
}
}
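The fixup exists because `repeat(5)` leaves two consecutive fake tokens between sub-images; collapsing them leaves exactly one separator, which is what the fake-token count in the new test further down relies on. A toy sketch with single-letter stand-ins (`F` for `<fake_token_around_image>`, `I` for `<image>`), purely illustrative:

```rust
use std::iter;

fn main() {
    // One sub-image worth of markup, with a single image slot for readability.
    let slots = 1;
    let mut s = String::from("F");
    s.extend(iter::repeat('I').take(slots));
    s.push('F');

    let split = s.repeat(5);              // "FIFFIFFIFFIFFIF": doubled "FF" between sub-images
    let fixed = split.replace("FF", "F"); // "FIFIFIFIFIF": a single separator remains
    assert_eq!(fixed.matches('F').count(), 6);
    assert_eq!(fixed.matches('I').count(), 5);
}
```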
/// Get input length and optionally truncate it
fn prepare_input(
inputs: String,
_truncate: Option<usize>,
tokenizer: &Tokenizer,
-config: &Option<Config>,
+config: Option<&Config>,
+preprocessor_config: Option<&HubPreprocessorConfig>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
+use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
-Some(Config::LlavaNext(config)) => {
+Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
@@ -529,88 +597,17 @@ fn prepare_input(
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-tokenizer_query.push_str(&"<image>".repeat(slots));
+tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
Some(Config::Paligemma(config)) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots));
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
Some(Config::Idefics2(config)) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
+tokenizer_query = image_tokens_fixup(config, tokenizer_query);
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
Some(Config::Idefics) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, _height, _width) =
fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
@@ -750,7 +747,7 @@ pub enum ValidationError {
#[cfg(test)]
mod tests {
use super::*;
-use crate::config::{PaliTextConfig, Paligemma};
+use crate::config::{Idefics2, PaliTextConfig, Paligemma};
use crate::default_parameters;
use crate::tests::get_tokenizer;
@@ -769,6 +766,7 @@ mod tests {
workers,
tokenizer,
config,
+None,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
@@ -803,6 +801,7 @@ mod tests {
workers,
tokenizer,
config,
+None,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
@@ -836,6 +835,7 @@ mod tests {
workers,
tokenizer,
config,
+None,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
@@ -874,6 +874,7 @@ mod tests {
workers,
tokenizer,
config,
+None,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
@@ -941,6 +942,7 @@ mod tests {
workers,
tokenizer,
config,
+None,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
@@ -1026,6 +1028,7 @@ mod tests {
workers,
tokenizer,
Some(config),
+None,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
@@ -1058,4 +1061,83 @@ mod tests {
"Failed to process images",
);
}
#[tokio::test]
async fn test_idefics2_correct_n_fake_tokens() {
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let disable_grammar_support = true;
let workers = 1;
let config = Config::Idefics2(Idefics2 {});
let validation = Validation::new(
workers,
tokenizer,
Some(config),
Some(HubPreprocessorConfig::Idefics2Processor(
Idefics2Preprocessor {
do_image_splitting: true,
},
)),
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
disable_grammar_support,
);
let (encoding, chunks) = match validation
.tokenize(
format!(
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF
),
None,
)
.await
{
Ok(Some((encoding, chunks))) => (encoding, chunks),
_ => panic!("Unexpected tokenization failure"),
};
assert!(
chunks
== vec![
Chunk::Text("test".to_string()).into(),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
})
.into(),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
})
.into()
],
"Failed to process images",
);
// Verify the number of fake tokens:
//
// - Two images surrounded/separated by a fake token = 3.
// - Both are split in 5 subimages, separated by a fake token: 2 * 4
//
// Fake tokens get split up by the testing tokenizer, but we don't care.
assert_eq!(
encoding
.get_tokens()
.iter()
.filter(|t| *t == "fake")
.count(),
11
);
}
} }

View File

@@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
)
+@dataclass
+class GPTQParams:
+bits: int
+checkpoint_format: Optional[str]
+groupsize: int
+desc_act: bool
+quant_method: str
+sym: bool
@dataclass
class GPTQWeight:
qweight: torch.Tensor

View File

@@ -166,35 +166,45 @@ def get_linear(weight, bias, quantize):
elif quantize == "gptq":
from text_generation_server.layers.gptq import GPTQWeight
+from text_generation_server.layers.marlin import (
+GPTQMarlinLinear,
+GPTQMarlinWeight,
+)
-if not isinstance(weight, GPTQWeight):
+if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQWeight):
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
)
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
weight.bits,
weight.groupsize,
)
else:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
)
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
weight.bits,
weight.groupsize,
)
elif quantize == "awq": elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.gptq import GPTQWeight
@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
from text_generation_server.layers.marlin import ( from text_generation_server.layers.marlin import (
GPTQMarlin24Linear, GPTQMarlin24Linear,
GPTQMarlin24Weight, GPTQMarlin24Weight,
GPTQMarlinLinear,
GPTQMarlinWeight,
MarlinLinear, MarlinLinear,
MarlinWeight, MarlinWeight,
) )
if isinstance(weight, GPTQMarlinWeight): if isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear( linear = GPTQMarlin24Linear(
weight=weight, weight=weight,
bias=bias, bias=bias,

View File

@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
+from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.import_utils import SYSTEM
try:
@@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize == "gptq"
and gptq_params.quant_method == "gptq"
and gptq_params.bits in GPTQ_MARLIN_BITS
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
and gptq_params.sym
)
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(

View File

@@ -68,6 +68,9 @@ try:
from text_generation_server.models.flash_gemma import (
FlashGemma,
)
+from text_generation_server.models.flash_gemma2 import (
+FlashGemma2,
+)
from text_generation_server.models.pali_gemma import (
PaliGemma,
)
@@ -102,6 +105,7 @@ if FLASH_ATTENTION:
__all__.append(FlashQwen2)
__all__.append(FlashStarcoder2)
__all__.append(FlashGemma)
+__all__.append(FlashGemma2)
__all__.append(FlashCohere)
MAMBA_AVAILABLE = True
@@ -145,6 +149,11 @@ class ModelType(enum.Enum):
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
+GEMMA2 = {
+"type": "gemma2",
+"name": "Gemma2",
+"url": "https://huggingface.co/google/gemma2-9b",
+}
COHERE = {
"type": "cohere",
"name": "Cohere",
@@ -637,6 +646,27 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == GEMMA2:
if FLASH_ATTENTION:
return FlashGemma2(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == COHERE:
if FLASH_ATTENTION:

View File

@@ -0,0 +1,500 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
class Gemma2Config(PretrainedConfig):
def __init__(
self,
vocab_size=256128,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.head_dim = head_dim
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
weights.dtype = dtype
new = cls(weight, eps)
new.dtype = dtype
return new
# perform the multiplication in full precision and downcast after
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.dtype), residual
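
For intuition, a minimal functional re-implementation of the two Gemma-specific details above, outside of the server's `FastRMSNorm` machinery: the checkpoint stores the scale as an offset around zero (hence the `+ 1` at load time), and the normalization itself runs in float32 before downcasting. This is only a sketch, not the class used at serving time:

```python
import torch


def gemma2_rms_norm(x: torch.Tensor, stored_weight: torch.Tensor,
                    eps: float = 1e-6) -> torch.Tensor:
    """Reference RMSNorm following the (1 + weight) convention used above."""
    out_dtype = x.dtype
    weight = stored_weight.to(torch.float32) + 1  # checkpoint stores weight - 1

    # Perform the multiplication in full precision and downcast after.
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight).to(out_dtype)


x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.zeros(8)  # a stored weight of 0 corresponds to a unit scale
print(gemma2_rms_norm(x, w).dtype)  # torch.bfloat16
```
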
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemma2Attention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
else:
self.window_size = -1
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
# self.softmax_scale = self.head_size**-0.5
self.softmax_scale = config.query_pre_attn_scalar**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
)
# Decode
else:
paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class Gemma2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
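
The fused projection packs `gate_proj` and `up_proj` into a single matmul whose output is split back into two halves. A minimal sketch of the same gated-GELU pattern with plain `nn.Linear` layers (sizes are arbitrary illustrations):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLPSketch(nn.Module):
    """down_proj( gelu(gate(x)) * up(x) ), with gate and up fused into one matmul."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One projection producing [gate | up] along the output dimension.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.intermediate_size = intermediate_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x).view(*x.shape[:-1], 2, self.intermediate_size)
        gate, up = gate_up.unbind(dim=-2)
        return self.down_proj(F.gelu(gate, approximate="tanh") * up)


mlp = GatedMLPSketch(hidden_size=16, intermediate_size=64)
print(mlp(torch.randn(4, 16)).shape)  # torch.Size([4, 16])
```
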
class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
causal=causal,
is_sliding=is_sliding,
)
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.pre_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.post_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
normed_attn_res_output = normed_attn_res_output + res
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
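
Compared to the original Gemma layer, Gemma 2 normalizes the output of each sub-block as well as its input, so every residual add happens on a normalized tensor. Below is a compact sketch of the data flow implemented by the forward above, with the norm replaced by a trivial stand-in and the attention/MLP passed in as opaque callables; it is meant to show the ordering only:

```python
import torch
import torch.nn.functional as F


def norm(x, residual=None):
    # Stand-in for Gemma2FastRMSNorm: optional residual add, then normalize.
    if residual is not None:
        x = x + residual
    return F.normalize(x, dim=-1), x


def gemma2_layer(hidden_states, residual, attn, mlp):
    """Sandwich-norm layer: each sub-block output is normalized before its residual add."""
    normed, res = norm(hidden_states, residual)   # input_layernorm
    attn_out, _ = norm(attn(normed))              # post_attention_layernorm
    hidden = attn_out + res

    pre_ffn, _ = norm(hidden)                     # pre_feedforward_layernorm
    mlp_out, _ = norm(mlp(pre_ffn))               # post_feedforward_layernorm
    # The next layer's input_layernorm adds mlp_out and hidden back together.
    return mlp_out, hidden


out, res = gemma2_layer(torch.randn(4, 16), None, attn=lambda t: t, mlp=lambda t: t)
print(out.shape, res.shape)  # torch.Size([4, 16]) torch.Size([4, 16])
```
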
class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGemma2Layer(
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
causal=causal,
is_sliding=layer_id % 2 == 0,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
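
A Gemma 2 specific detail in the constructor above is the alternation between local and global attention: even-numbered layers are built with `is_sliding=True` and therefore use `config.sliding_window`, while odd-numbered layers get a window size of -1 (unbounded). A tiny sketch of the resulting per-layer window sizes, with hypothetical config values:

```python
SLIDING_WINDOW = 4096  # hypothetical config.sliding_window
NUM_LAYERS = 8         # hypothetical config.num_hidden_layers

# Even layers attend within a sliding window, odd layers attend globally (-1).
window_sizes = [
    SLIDING_WINDOW if layer_id % 2 == 0 else -1
    for layer_id in range(NUM_LAYERS)
]
print(window_sizes)  # [4096, -1, 4096, -1, 4096, -1, 4096, -1]
```
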
class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
embed_norm = config.hidden_size**0.5
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemma2Model(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load(
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -375,8 +375,6 @@ class FlashGemmaModel(torch.nn.Module):
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
) )
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

View File

@ -39,7 +39,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Args: Args:
image_size (`tuple`): image_size (`tuple`):
The size of the input image in the format (width, height). The size of the input image in the format (height, width).
grid_pinpoints (`List`): grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`. of the form `(height, width)`.
@ -47,7 +47,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
The size of each image patch. The size of each image patch.
Returns: Returns:
tuple: The shape of the image patch grid in the format (width, height). tuple: The shape of the image patch grid in the format (height, width).
""" """
if not isinstance(grid_pinpoints, list): if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists") raise ValueError("grid_pinpoints should be a list of tuples or lists")
@ -230,7 +230,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
raise ValueError( raise ValueError(
"The number of patches is not consistent with the image size." "The number of patches is not consistent with the image size."
) )
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx], image_sizes[image_idx],
self.config.image_grid_pinpoints, self.config.image_grid_pinpoints,
self.config.vision_config.image_size, self.config.vision_config.image_size,

View File

@ -28,8 +28,12 @@ from text_generation_server.models.types import (
GeneratedText, GeneratedText,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
import text_generation_server.models.globals as tgi_globals from text_generation_server.models.globals import (
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS MEM_POOL,
CUDA_GRAPHS,
get_adapter_to_index,
MODEL_ID,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@ -233,7 +237,8 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens) top_n_tokens.append(r.top_n_tokens)
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0) ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((input_length,), adapter_index)) adapter_indices_list.append(torch.full((input_length,), adapter_index))
adapter_set.add(adapter_index) adapter_set.add(adapter_index)
@ -499,9 +504,8 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx]) top_n_tokens.append(self.top_n_tokens[idx])
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get( ADAPTER_TO_INDEX = get_adapter_to_index()
self.requests[idx].adapter_id, 0 adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
)
adapter_set.add(adapter_index) adapter_set.add(adapter_index)
remaining_tokens = ( remaining_tokens = (

View File

@ -0,0 +1,75 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PretrainedConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -34,3 +34,8 @@ ADAPTER_TO_INDEX: Dict[str, int] = None
def set_adapter_to_index(adapter_to_index: Dict[str, int]): def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX = adapter_to_index ADAPTER_TO_INDEX = adapter_to_index
def get_adapter_to_index():
global ADAPTER_TO_INDEX
return ADAPTER_TO_INDEX
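
The getter exists because of Python's import semantics: `from ... import ADAPTER_TO_INDEX` copies the binding at import time, so a later call to `set_adapter_to_index` in this module is invisible to code that imported the name directly, while a function that reads the module-level global at call time always sees the current value. A self-contained illustration of the difference (the adapter name is made up):

```python
from typing import Dict, Optional

ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]) -> None:
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


def get_adapter_to_index() -> Optional[Dict[str, int]]:
    return ADAPTER_TO_INDEX


# `from globals import ADAPTER_TO_INDEX` in another module behaves like this
# local snapshot: it captures the current binding and never sees re-assignment.
imported_snapshot = ADAPTER_TO_INDEX

set_adapter_to_index({"my-adapter": 1})

print(imported_snapshot)       # None -> stale copy of the old binding
print(get_adapter_to_index())  # {'my-adapter': 1} -> reads the live module global
```
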

View File

@ -39,7 +39,9 @@ class PaliGemmaBatch(VlmCausalLMBatch):
# TODO do_convert_RGB should be on by default ? # TODO do_convert_RGB should be on by default ?
image = image.convert("RGB") image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt") image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id) full_text += image_text_replacement(
processor, image_input, config, image_id
)
image_inputs.append(image_input) image_inputs.append(image_input)
else: else:
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")

View File

@ -1,3 +1,4 @@
from itertools import repeat
import torch import torch
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
@ -15,6 +16,9 @@ from text_generation_server.models.flash_mistral import (
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
""" """
@ -22,7 +26,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Args: Args:
image_size (`tuple`): image_size (`tuple`):
The size of the input image in the format (width, height). The size of the input image in the format (height, width).
grid_pinpoints (`List`): grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`. of the form `(height, width)`.
@ -39,15 +43,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size return height // patch_size, width // patch_size
def image_text_replacement(image_input, config, image_id) -> str: def image_text_replacement(processor, image_input, config, image_id: int) -> str:
if config.model_type == "idefics2": if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented. image_seq_len = 64
num_features = 320 image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
return ( if processor.image_processor.do_image_splitting:
"<fake_token_around_image>" image_str *= 5
+ "<image>" * num_features return image_str
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next": elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id] height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config) num_features = get_number_of_features(height, width, config)
@ -64,20 +66,35 @@ def image_text_replacement(image_input, config, image_id) -> str:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal") raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2":
return text.replace(
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
)
return text
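
The fixup only matters when two image chunks are adjacent in a prompt: each Idefics2 replacement starts and ends with the fake token, so back-to-back images would otherwise leave a doubled `<fake_token_around_image>` at the seam. A short sketch of the effect, with `image_seq_len` shrunk to 2 for readability:

```python
FAKE = "<fake_token_around_image>"
IMG = "<image>"
image_seq_len = 2  # shortened from 64 purely for readability

replacement = f"{FAKE}{IMG * image_seq_len}{FAKE}"

# Two adjacent images produce a doubled fake token at the seam ...
prompt = "Compare these: " + replacement + replacement
assert FAKE + FAKE in prompt

# ... which the fixup collapses back into a single one.
fixed = prompt.replace(FAKE + FAKE, FAKE)
print(FAKE + FAKE in fixed)  # False
```
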
def get_unpadded_features( def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
current_height = npatches * num_patch_height current_height = npatches * num_patch_height
current_width = npatches * num_patch_width current_width = npatches * num_patch_width
aspect_ratio: float = width / height aspect_ratio: float = original_width / original_height
current_aspect_ratio: float = current_width / current_height current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio: if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width new_height = (original_height * current_width) // original_width
current_height = new_height padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else: else:
new_width = (width * current_height) // height new_width = (original_width * current_height) // original_height
current_width = new_width padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = current_height * current_width unpadded_features = current_height * current_width
newline_features = current_height newline_features = current_height
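
The rewritten branch mirrors upstream LLaVA-NeXT: rather than setting the padded dimension directly to the aspect-ratio-matched size, it subtracts twice the integer padding, which can leave one extra row or column compared to the old formula. A worked example with hypothetical values (a 640x360 image, 24 patches per side, a 2x2 patch grid):

```python
# Hypothetical inputs, for illustration only.
original_height, original_width = 360, 640
npatches, num_patch_height, num_patch_width = 24, 2, 2

current_height = npatches * num_patch_height  # 48
current_width = npatches * num_patch_width    # 48

aspect_ratio = original_width / original_height        # ~1.78
current_aspect_ratio = current_width / current_height  # 1.0

if aspect_ratio > current_aspect_ratio:
    new_height = (original_height * current_width) // original_width  # 27
    padding = (current_height - new_height) // 2                      # 10
    current_height = current_height - (2 * padding)                   # 28 (old formula: 27)
else:
    new_width = (original_width * current_height) // original_height
    padding = (current_width - new_width) // 2
    current_width = current_width - (2 * padding)

unpadded_features = current_height * current_width  # 1344
newline_features = current_height                   # 28
print(unpadded_features, newline_features)
```
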
@ -96,7 +113,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
npatches = image_size // patch_size npatches = image_size // patch_size
num_patch_height, num_patch_width = get_anyres_image_grid_shape( # Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
[height, width], [height, width],
image_grid_pinpoints, image_grid_pinpoints,
image_size, image_size,
@ -168,9 +187,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if chunk_type == "text": if chunk_type == "text":
full_text += chunk.text full_text += chunk.text
elif chunk_type == "image": elif chunk_type == "image":
full_text += image_text_replacement(image_inputs, config, image_id) full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
image_id += 1 image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text) batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)

View File

@ -1,25 +1,15 @@
import os import os
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Union
from safetensors import safe_open, SafetensorError from safetensors import safe_open, SafetensorError
import torch import torch
from loguru import logger from loguru import logger
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import json import json
from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
@dataclass
class _GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
sym: bool
class Weights: class Weights:
def __init__( def __init__(
self, self,
@ -212,6 +202,10 @@ class Weights:
""" """
if quantize in ["gptq", "awq"]: if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try: try:
qweight = self.get_packed_sharded( qweight = self.get_packed_sharded(
@ -221,17 +215,28 @@ class Weights:
raise RuntimeError( raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized." f"Cannot load `{quantize}` weight, make sure the model is already quantized."
) )
gptq_params = self._get_gptq_params()
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
scales = self.get_packed_sharded( scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes f"{prefix}.scales", dim=1, block_sizes=block_sizes
) )
scales = scales.to(dtype=self.dtype) scales = scales.to(dtype=self.dtype)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
g_idx = self.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if quantize == "gptq" and gptq_params.quant_method == "gptq": if quantize == "gptq" and gptq_params.quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx") g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and gptq_params.quant_method == "awq": elif quantize == "gptq" and gptq_params.quant_method == "awq":
@ -269,7 +274,6 @@ class Weights:
repack_gptq_for_marlin, repack_gptq_for_marlin,
) )
quant_method = getattr(self, "quant_method", "marlin")
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24: if is_marlin_24:
B = self.get_packed_sharded( B = self.get_packed_sharded(
@ -286,31 +290,6 @@ class Weights:
weight = GPTQMarlin24Weight( weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
) )
elif quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = self.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
g_idx = self.get_tensor(f"{prefix}.g_idx")
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
else: else:
B = self.get_packed_sharded( B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes f"{prefix}.B", dim=1, block_sizes=block_sizes
@ -356,6 +335,10 @@ class Weights:
raise ValueError("get_multi_weights_col is not supported for exl2") raise ValueError("get_multi_weights_col is not supported for exl2")
elif quantize in ["gptq", "awq"]: elif quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try: try:
qweight = torch.cat( qweight = torch.cat(
@ -366,14 +349,31 @@ class Weights:
f"Cannot load `{quantize}` weight, make sure the model is already quantized" f"Cannot load `{quantize}` weight, make sure the model is already quantized"
) )
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat( scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
) )
gptq_params = self._get_gptq_params() gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA from text_generation_server.layers.gptq import HAS_EXLLAMA
@ -425,10 +425,8 @@ class Weights:
from text_generation_server.layers.marlin import ( from text_generation_server.layers.marlin import (
GPTQMarlin24Weight, GPTQMarlin24Weight,
MarlinWeight, MarlinWeight,
repack_gptq_for_marlin,
) )
quant_method = getattr(self, "quant_method", "marlin")
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24: if is_marlin_24:
try: try:
@ -452,36 +450,6 @@ class Weights:
weight = GPTQMarlin24Weight( weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
) )
elif quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
dim=1,
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
else: else:
try: try:
B = torch.cat( B = torch.cat(
@ -544,9 +512,41 @@ class Weights:
) )
elif quantize == "gptq": elif quantize == "gptq":
use_exllama = True from text_generation_server.layers.marlin import (
gptq_params = self._get_gptq_params() can_use_gptq_marlin,
repack_gptq_for_marlin,
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if gptq_params.bits != 4: if gptq_params.bits != 4:
use_exllama = False use_exllama = False
@ -672,10 +672,8 @@ class Weights:
from text_generation_server.layers.marlin import ( from text_generation_server.layers.marlin import (
GPTQMarlin24Weight, GPTQMarlin24Weight,
MarlinWeight, MarlinWeight,
repack_gptq_for_marlin,
) )
quant_method = getattr(self, "quant_method", "marlin")
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24: if is_marlin_24:
try: try:
@ -698,35 +696,6 @@ class Weights:
weight = GPTQMarlin24Weight( weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
) )
elif quant_method == "gptq":
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
else: else:
try: try:
B = self.get_sharded(f"{prefix}.B", dim=0) B = self.get_sharded(f"{prefix}.B", dim=0)
@ -743,18 +712,17 @@ class Weights:
else: else:
s = self.get_sharded(f"{prefix}.s", dim=0) s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s) weight = MarlinWeight(B=B, s=s)
else: else:
weight = self.get_sharded(f"{prefix}.weight", dim=1) weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight return weight
def _get_gptq_params(self) -> _GPTQParams: def _get_gptq_params(self) -> GPTQParams:
try: try:
bits = self.get_tensor("gptq_bits").item() bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item() groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None) checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False desc_act = False
sym = True sym = False
quant_method = "gptq" quant_method = "gptq"
except (SafetensorError, RuntimeError) as e: except (SafetensorError, RuntimeError) as e:
try: try:
@ -767,7 +735,7 @@ class Weights:
except Exception: except Exception:
raise e raise e
return _GPTQParams( return GPTQParams(
bits=bits, bits=bits,
checkpoint_format=checkpoint_format, checkpoint_format=checkpoint_format,
desc_act=desc_act, desc_act=desc_act,