From 70217ac3454396d9a08a25ce1aa8b40a1fe87069 Mon Sep 17 00:00:00 2001 From: Yuan Wu Date: Thu, 29 May 2025 15:58:24 +0800 Subject: [PATCH 1/9] [Gaudi] Fix the OOM issue of Llama-4-Scout-17B-16E-Instruct (#3245) Signed-off-by: yuanwu --- .../models/custom_modeling/flash_llama_modeling.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0edea03a..dfb16621 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -143,12 +143,14 @@ class FlashLlamaAttention(torch.nn.Module): config.num_key_value_heads = getattr( config, "num_key_value_heads", config.num_attention_heads ) - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + + if config.model_type != "llama4_text": + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) # `config.attention_multiplier` is used in Granite self.softmax_scale = getattr( From 6b6e30a6f680b45884cdd144d57ac5115d69090e Mon Sep 17 00:00:00 2001 From: Yuan Wu Date: Thu, 29 May 2025 17:38:44 +0800 Subject: [PATCH 2/9] [gaudi] Fix the Llama-4-Maverick-17B-128E crash issue (#3246) Signed-off-by: yuanwu --- .../models/custom_modeling/flash_llama4_modeling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 11864c52..0e3af85a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -48,7 +48,6 @@ from text_generation_server.layers.attention import ( ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaAttention, - LlamaMLP, ) @@ -444,7 +443,7 @@ class Llama4TextDecoderLayer(nn.Module): if self.is_moe_layer: # the 128E model interleaves dense / sparse self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights) else: - self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights) + self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", From 249189d96e2c5aa6197221cc67a5b649a2370ade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 30 May 2025 16:16:36 +0200 Subject: [PATCH 3/9] Prepare for 3.3.2 (#3249) --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 2 +- README.md | 6 +++--- docs/openapi.json | 2 +- docs/source/backends/gaudi.mdx | 10 +++++----- docs/source/backends/neuron.md | 2 +- .../source/basic_tutorials/gated_model_access.md | 2 +- docs/source/conceptual/quantization.md | 6 +++--- docs/source/installation_amd.md | 2 +- docs/source/installation_intel.md | 4 ++-- docs/source/installation_nvidia.md | 2 +- docs/source/quicktour.md | 4 ++-- docs/source/reference/api_reference.md | 2 +- .../test_flash_gemma3_image_base64_rgb_jpg.json | 2 +- .../test_flash_gemma3_image_base64_rgb_png.json | 2 +- .../test_flash_gemma3_image_base64_rgba.json | 2 +- .../test_flash_gemma3_image_cow.json | 2 +- .../test_flash_gemma3_image_cow_dog.json | 2 +- .../test_json_schema_basic.json | 2 +- .../test_json_schema_complex.json | 2 +- .../test_mllama/test_mllama_load.json | 4 ++-- .../test_mllama/test_mllama_simpl.json | 2 +- 22 files changed, 40 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b09f1c3f..c4b2572f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4650,7 +4650,7 @@ dependencies = [ [[package]] name = "text-generation-backends-trtllm" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "async-trait", "clap 4.5.32", @@ -4671,7 +4671,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "average", "clap 4.5.32", @@ -4691,7 +4691,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -4709,7 +4709,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "clap 4.5.32", "ctrlc", @@ -4730,7 +4730,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "anyhow", "async-stream", @@ -4782,7 +4782,7 @@ dependencies = [ [[package]] name = "text-generation-router-llamacpp" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "async-trait", "bindgen 0.71.1", @@ -4800,7 +4800,7 @@ dependencies = [ [[package]] name = "text-generation-router-v2" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4849,7 +4849,7 @@ dependencies = [ [[package]] name = "text-generation-router-v3" -version = "3.3.1-dev0" +version = "3.3.2-dev0" dependencies = [ "async-stream", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index f7b1e3b7..06dc251b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ default-members = [ resolver = "2" [workspace.package] -version = "3.3.1-dev0" +version = "3.3.2-dev0" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/README.md b/README.md index f4c6c562..5586e0c7 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta volume=$PWD/data docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model ``` And then you can make requests like @@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \ **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. -**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1-rocm --model-id $model` instead of the command above. +**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2-rocm --model-id $model` instead of the command above. To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): ``` @@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading token= docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model ``` ### A note on Shared Memory (shm) diff --git a/docs/openapi.json b/docs/openapi.json index 9249acad..ff63c3da 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "3.3.1-dev0" + "version": "3.3.2-dev0" }, "paths": { "/": { diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx index ab882fc2..49c6739d 100644 --- a/docs/source/backends/gaudi.mdx +++ b/docs/source/backends/gaudi.mdx @@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN docker run --runtime=habana --cap-add=sys_nice --ipc=host \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ - ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model ``` @@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN docker run --runtime=habana --cap-add=sys_nice --ipc=host \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ - ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model ``` @@ -115,7 +115,7 @@ docker run -p 8080:80 \ -e BATCH_BUCKET_SIZE=256 \ -e PREFILL_BATCH_BUCKET_SIZE=4 \ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \ - ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model \ --sharded true --num-shard 8 \ --max-input-tokens 1024 --max-total-tokens 2048 \ @@ -141,7 +141,7 @@ docker run -p 8080:80 \ -v $volume:/data \ -e PREFILL_BATCH_BUCKET_SIZE=1 \ -e BATCH_BUCKET_SIZE=1 \ - ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model \ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \ --max-total-tokens 8192 --max-batch-size 4 @@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \ -e PROF_PATH=/tmp/hpu_profile \ -e PROF_RANKS=0 \ -e PROF_RECORD_SHAPES=True \ - ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model ``` diff --git a/docs/source/backends/neuron.md b/docs/source/backends/neuron.md index a1fa3a9e..10c8a4fd 100644 --- a/docs/source/backends/neuron.md +++ b/docs/source/backends/neuron.md @@ -31,7 +31,7 @@ deployment instructions in the model card: The service is launched simply by running the text-generation-inference container with two sets of parameters: ``` -docker run ghcr.io/huggingface/text-generation-inference:3.3.1-neuron +docker run ghcr.io/huggingface/text-generation-inference:3.3.2-neuron ``` - system parameters are used to map ports, volumes and devices between the host and the service, diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index dfed553e..50c71ab5 100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -19,6 +19,6 @@ docker run --gpus all \ --shm-size 1g \ -e HF_TOKEN=$token \ -p 8080:80 \ - -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 \ + -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 \ --model-id $model ``` diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index c215f4c3..a666a48a 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes ``` 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. @@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes-nf4 +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes-nf4 ``` You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). @@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$ TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize gptq +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize gptq ``` Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 9f92859c..19fbe8ba 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1-rocm \ + ghcr.io/huggingface/text-generation-inference:3.3.2-rocm \ --model-id $model ``` diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md index 71c8a2de..c1a2e867 100644 --- a/docs/source/installation_intel.md +++ b/docs/source/installation_intel.md @@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1-intel-xpu \ + ghcr.io/huggingface/text-generation-inference:3.3.2-intel-xpu \ --model-id $model --cuda-graphs 0 ``` @@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1-intel-cpu \ + ghcr.io/huggingface/text-generation-inference:3.3.2-intel-cpu \ --model-id $model --cuda-graphs 0 ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 40ae145b..3aede5a9 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1 \ + ghcr.io/huggingface/text-generation-inference:3.3.2 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 76832317..f1d2c92a 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.1 \ + ghcr.io/huggingface/text-generation-inference:3.3.2 \ --model-id $model ``` @@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. ```bash -docker run ghcr.io/huggingface/text-generation-inference:3.3.1 --help +docker run ghcr.io/huggingface/text-generation-inference:3.3.2 --help ``` diff --git a/docs/source/reference/api_reference.md b/docs/source/reference/api_reference.md index 8dbe977a..5830f7b9 100644 --- a/docs/source/reference/api_reference.md +++ b/docs/source/reference/api_reference.md @@ -163,7 +163,7 @@ hub = { # create Hugging Face Model Class huggingface_model = HuggingFaceModel( - image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.1"), + image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.2"), env=hub, role=role, ) diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json index df9daac8..0c02702e 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 42, "prompt_tokens": 277, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json index 328105ca..0bb67dfb 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 62, "prompt_tokens": 277, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json index b7918d48..dc1309d2 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 67, "prompt_tokens": 277, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json index 43d01863..7f7d0ef6 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 72, "prompt_tokens": 275, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json index 9d80a763..35ca9cf0 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 80, "prompt_tokens": 279, diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json index 30241eb9..c93f8a67 100644 --- a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json @@ -14,7 +14,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 35, "prompt_tokens": 32, diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json index 008ae5b0..326d6702 100644 --- a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json @@ -14,7 +14,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 44, "prompt_tokens": 37, diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json index 50e75361..682e10d4 100644 --- a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json @@ -18,7 +18,7 @@ "id": "", "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, @@ -44,7 +44,7 @@ "id": "", "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json index 91297113..c3c5e76b 100644 --- a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json @@ -17,7 +17,7 @@ "id": "", "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, From 1ff9d185d533b509ee43a025a36663984197478a Mon Sep 17 00:00:00 2001 From: Yuan Wu Date: Tue, 3 Jun 2025 19:42:29 +0800 Subject: [PATCH 4/9] Remove useless packages (#3253) Signed-off-by: yuanwu --- backends/gaudi/server/poetry.lock | 250 ------------------------- backends/gaudi/server/requirements.txt | 14 -- 2 files changed, 264 deletions(-) diff --git a/backends/gaudi/server/poetry.lock b/backends/gaudi/server/poetry.lock index b9b2e138..c6cace66 100644 --- a/backends/gaudi/server/poetry.lock +++ b/backends/gaudi/server/poetry.lock @@ -1058,199 +1058,6 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -description = "CUBLAS native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"}, -] - -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -description = "CUDA profiling tools runtime libs." -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"}, -] - -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -description = "NVRTC native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"}, -] - -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -description = "CUDA Runtime native Libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"}, -] - -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -description = "cuDNN runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, - {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" - -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -description = "CUFFT native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"}, -] - -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.5.147" -description = "CURAND native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"}, -] - -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -description = "CUDA solver native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" -nvidia-cusparse-cu12 = "*" -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -description = "CUSPARSE native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"}, -] - -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-cusparselt-cu12" -version = "0.6.2" -description = "NVIDIA cuSPARSELt" -optional = false -python-versions = "*" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8"}, - {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9"}, - {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70"}, -] - -[[package]] -name = "nvidia-nccl-cu12" -version = "2.21.5" -description = "NVIDIA Collective Communication Library (NCCL) Runtime" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"}, -] - -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -description = "Nvidia JIT LTO Library" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, -] - -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -description = "NVIDIA Tools Extension" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, -] - [[package]] name = "opentelemetry-api" version = "1.32.0" @@ -2650,63 +2457,6 @@ files = [ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] -[[package]] -name = "torch" -version = "2.6.0" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -optional = false -python-versions = ">=3.9.0" -groups = ["main"] -files = [ - {file = "torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961"}, - {file = "torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab"}, - {file = "torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341"}, - {file = "torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628"}, - {file = "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1"}, - {file = "torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d"}, - {file = "torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7"}, - {file = "torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21"}, - {file = "torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9"}, - {file = "torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb"}, - {file = "torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239"}, - {file = "torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989"}, - {file = "torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf"}, - {file = "torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b"}, - {file = "torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc"}, - {file = "torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"}, - {file = "torch-2.6.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053"}, - {file = "torch-2.6.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e"}, - {file = "torch-2.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a"}, - {file = "torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c"}, -] - -[package.dependencies] -filelock = "*" -fsspec = "*" -jinja2 = "*" -networkx = "*" -nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparselt-cu12 = {version = "0.6.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -setuptools = {version = "*", markers = "python_version >= \"3.12\""} -sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""} -triton = {version = "3.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -typing-extensions = ">=4.10.0" - -[package.extras] -opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.13.0)"] - [[package]] name = "tqdm" version = "4.67.1" diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt index 1a5d767f..6f897722 100644 --- a/backends/gaudi/server/requirements.txt +++ b/backends/gaudi/server/requirements.txt @@ -36,19 +36,6 @@ nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13" networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13" @@ -88,7 +75,6 @@ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13" sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13" -torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13" triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" From 79183d164728f080e1a571b7ff1f58bd0ed840b0 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 10 Jun 2025 17:56:25 +0200 Subject: [PATCH 5/9] Bump neuron SDK version (#3260) * chore(neuron): bump version to 0.2.0 * refactor(neuron): use named parameters in inputs helpers This allows to hide the differences between the two backends in terms of input parameters. * refactor(neuron): remove obsolete code paths * fix(neuron): use neuron_config whenever possible * fix(neuron): use new cache import path * fix(neuron): neuron config is not stored in config anymore * fix(nxd): adapt model retrieval to new APIs * fix(generator): emulate greedy in sampling parameters When on-device sampling is enabled, we need to emulate the greedy behaviour using top-k=1, top-p=1, temperature=1. * test(neuron): update models and expectations * feat(neuron): support on-device sampling * fix(neuron): adapt entrypoint * tests(neuron): remove obsolete models * fix(neuron): adjust test expectations for llama on nxd --- Dockerfile.neuron | 21 ++- .../text_generation_server/generator.py | 140 ++++++++++------- .../server/text_generation_server/model.py | 51 +++--- .../text_generation_server}/tgi_env.py | 145 ++++++++++-------- backends/neuron/tests/fixtures/model.py | 74 ++------- .../neuron/tests/server/test_cached_model.py | 42 +++++ .../tests/server/test_continuous_batching.py | 4 +- backends/neuron/tests/server/test_decode.py | 12 +- backends/neuron/tests/server/test_prefill.py | 26 ++-- backends/neuron/tests/test_entry_point.py | 63 ++++++++ backends/neuron/tgi-entrypoint.sh | 2 +- backends/neuron/tgi_entry_point.py | 53 +++++++ .../fixtures/neuron/export_models.py | 18 --- integration-tests/neuron/test_generate.py | 6 +- 14 files changed, 393 insertions(+), 264 deletions(-) rename backends/neuron/{ => server/text_generation_server}/tgi_env.py (63%) mode change 100755 => 100644 create mode 100644 backends/neuron/tests/server/test_cached_model.py create mode 100644 backends/neuron/tests/test_entry_point.py create mode 100755 backends/neuron/tgi_entry_point.py diff --git a/Dockerfile.neuron b/Dockerfile.neuron index d22ca222..6228dbb7 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -5,7 +5,7 @@ RUN mkdir -p /tgi # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments FROM alpine AS optimum-neuron RUN mkdir -p /optimum-neuron -ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz +ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1 # Build cargo components (adapted from TGI original Dockerfile) @@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU # Install neuronx packages RUN apt-get update -y \ && apt-get install -y --no-install-recommends \ - aws-neuronx-dkms=2.19.64.0 \ - aws-neuronx-collectives=2.23.135.0-3e70920f2 \ - aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \ - aws-neuronx-tools=2.20.204.0 \ + aws-neuronx-dkms=2.20.28.0 \ + aws-neuronx-collectives=2.24.59.0-838c7fc8b \ + aws-neuronx-runtime-lib=2.24.53.0-f239092cc \ + aws-neuronx-tools=2.22.61.0 \ libxml2 \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean @@ -125,11 +125,10 @@ RUN pip3 install \ --index-url https://download.pytorch.org/whl/cpu RUN pip3 install \ - neuronx-cc==2.16.372.0 \ - torch-neuronx==2.5.1.2.4.0 \ - transformers-neuronx==0.13.322 \ - neuronx-distributed==0.10.1 \ - libneuronxla==2.1.681.0 \ + neuronx-cc==2.17.194.0 \ + torch-neuronx==2.5.1.2.6.0 \ + neuronx-distributed==0.11.0 \ + libneuronxla==2.2.1630.0 \ --extra-index-url=https://pip.repos.neuron.amazonaws.com # Install HuggingFace packages @@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz # Final image FROM neuron -COPY backends/neuron/tgi_env.py /tgi_env.py +COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py index b3887e14..10a4d7a2 100644 --- a/backends/neuron/server/text_generation_server/generator.py +++ b/backends/neuron/server/text_generation_server/generator.py @@ -7,7 +7,8 @@ from typing import List, Optional, Tuple import torch from loguru import logger -from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from optimum.neuron.configuration_utils import NeuronConfig from transformers.generation import GenerationConfig from optimum.neuron import NeuronModelForCausalLM @@ -175,6 +176,12 @@ class Slot: self._generation_config.top_p = request.parameters.top_p if request.parameters.typical_p != 0: self._generation_config.typical_p = request.parameters.typical_p + else: + # Set the sampling parameters to emulate greedy decoding when using on-device sampling + self._generation_config.temperature = 1.0 + self._generation_config.top_k = 1 + self._generation_config.top_p = 1.0 + self._generation_config.typical_p = 1.0 if request.parameters.repetition_penalty != 0: self._generation_config.repetition_penalty = ( request.parameters.repetition_penalty @@ -211,19 +218,11 @@ class Slot: self._mask = attention_mask.clone() self._selector = selector - def pause(self, reset_on_pause: bool): + def pause(self): """Mark the current slot as paused for generation. Note that the KV cache for this slot will still be filled. """ - if reset_on_pause: - # Drop the last token as it will be added back when resuming the slot - self._generated_tokens -= 1 - # Since generated tokens are now part of the prefill, we need to reevaluate - # max_new_tokens for the next generation - self._generation_config.max_new_tokens = ( - self._max_new_tokens - self._generated_tokens - ) self._state = Slot.State.PAUSE def resume(self): @@ -340,16 +339,27 @@ class NeuronGenerator(Generator): tokenizer: PreTrainedTokenizerBase, ): self.model = model - self.rebuild_cache_on_prefill = not self.model.continuous_batching + if not isinstance(self.model, NeuronModelForCausalLM): + raise ValueError("The model must be a NeuronModelForCausalLM.") + if not model.neuron_config.continuous_batching: + raise ValueError( + "The neuron model must be compiled with continuous_batching=True." + ) # Specify padding and truncation options for decoder-only architecture tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" tokenizer.truncation_side = "left" self.tokenizer = tokenizer self.special_tokens = self.tokenizer.all_special_ids - self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)] + self.slots = [ + Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size) + ] self.batch_id = 0 + @property + def on_device_sampling(self) -> bool: + return getattr(self.model.neuron_config, "on_device_sampling", False) + @property def info(self) -> InfoResponse: """Returns the expected InfoResponse.""" @@ -371,14 +381,22 @@ class NeuronGenerator(Generator): The maximum number of tokens the model supports. """ # Just check that the warmup request parameters match the model capacity - batch_size = self.model.batch_size + batch_size = self.model.neuron_config.batch_size if len(batch.requests) > batch_size: raise ValueError( - f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." + f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." ) self.prefill(batch) self.clear() - return self.model.batch_size * self.model.max_length + return ( + self.model.neuron_config.batch_size + * self.model.neuron_config.sequence_length + ) + + def max_prefill_length(self) -> int: + if hasattr(self.model.neuron_config, "max_context_length"): + return self.model.neuron_config.max_context_length + return self.model.neuron_config.sequence_length def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: """Prefill new requests. @@ -398,7 +416,7 @@ class NeuronGenerator(Generator): if len(empty_slots) < len(batch.requests): raise ValueError( f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots." - f" Please align max_batch_size with the static batch size: {self.model.batch_size}." + f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}." ) # Assign each request to an empty slot logger.debug( @@ -412,14 +430,8 @@ class NeuronGenerator(Generator): logger.debug( f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}" ) - if self.rebuild_cache_on_prefill: - # We will clear pending slots and prefill all slots - prefill_slots = self.slots - seq_ids = None - else: - # We only need to pass inputs for the new requests - prefill_slots = new_slots - seq_ids = torch.tensor([slot.id for slot in prefill_slots]) + prefill_slots = new_slots + seq_ids = torch.tensor([slot.id for slot in prefill_slots]) # Reconstruct the full inputs (without padding) as seen by the model. # This comprises: # - the inputs for new requests, @@ -431,8 +443,10 @@ class NeuronGenerator(Generator): inputs.append(slot.cached_text) # Apply truncation, making sure we fit into static dimensions if slot.truncate == 0: - max_length = self.model.max_length - elif slot.truncate > max_length and slot.truncate < self.model.max_length: + max_length = self.max_prefill_length() + elif ( + slot.truncate > max_length and slot.truncate < self.max_prefill_length() + ): max_length = slot.truncate # Tokenize with padding and truncation padded_inputs = self.tokenizer( @@ -444,13 +458,12 @@ class NeuronGenerator(Generator): ) input_ids = padded_inputs.input_ids attention_mask = padded_inputs.attention_mask + sampling_params = ( + torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None + ) # Pause previously active slots during generation - next_tokens = [] for slot in active_slots: - slot.pause(reset_on_pause=self.rebuild_cache_on_prefill) - if self.rebuild_cache_on_prefill: - # The slot will be reset, so we need to store its next token - next_tokens.append(slot.next_token) + slot.pause() # Each slot must be reset with the padded inputs and masks for i, slot in enumerate(prefill_slots): if slot.state != slot.state.EMPTY: @@ -464,29 +477,33 @@ class NeuronGenerator(Generator): slot_input_ids, slot.generation_config, self.model, - self.model.max_length, + self.model.neuron_config.sequence_length, tokenizer=self.tokenizer, seed=slot.seed, ) slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64) slot_attention_mask = attention_mask[i] slot.reset(slot_input_ids, slot_attention_mask, selector) + if sampling_params is not None: + sampling_params[i, 0] = slot.generation_config.top_k + sampling_params[i, 1] = slot.generation_config.top_p + sampling_params[i, 2] = slot.generation_config.temperature # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored, # as they have already been generated and sent back in the last decode. model_inputs = self.model.prepare_inputs_for_prefill( - input_ids, attention_mask, seq_ids + input_ids, + attention_mask=attention_mask, + seq_ids=seq_ids, + sampling_params=sampling_params, ) - logits = self.model(**model_inputs)[0] + tokens_or_logits = self.model(**model_inputs)[0] generation, next_batch = self._generate_token( - prefill_slots, self.batch_id, logits, input_ids + prefill_slots, self.batch_id, tokens_or_logits, input_ids ) self.batch_id += 1 # Reactivate previously active slots for the next decode for i, slot in enumerate(active_slots): slot.resume() - if self.rebuild_cache_on_prefill: - # Append back the next token - slot.append(next_tokens[i]) logger.debug("Model ready for decoding") if next_batch is not None: logger.debug( @@ -530,12 +547,8 @@ class NeuronGenerator(Generator): raise ValueError( "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)" ) - if self.model.continuous_batching: - decode_slots = active_slots - seq_ids = torch.tensor([slot.id for slot in decode_slots]) - else: - decode_slots = self.slots - seq_ids = None + decode_slots = active_slots + seq_ids = torch.tensor([slot.id for slot in decode_slots]) # Reconstruct input_ids and attention_mask from decode slots n_slots = len(decode_slots) input_ids = torch.full( @@ -545,22 +558,32 @@ class NeuronGenerator(Generator): for slot in decode_slots: max_length = max(max_length, slot.attention_mask.size(-1)) attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64) + sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None for i, slot in enumerate(decode_slots): if slot.state != Slot.State.EMPTY: # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached) input_ids[i, 0] = slot.next_token attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask + if sampling_params is not None: + sampling_params[i, 0] = slot.generation_config.top_k + sampling_params[i, 1] = slot.generation_config.top_p + sampling_params[i, 2] = slot.generation_config.temperature model_inputs = self.model.prepare_inputs_for_decode( - input_ids, attention_mask, seq_ids + input_ids, + attention_mask=attention_mask, + seq_ids=seq_ids, + sampling_params=sampling_params, + ) + tokens_or_logits = self.model(**model_inputs)[0] + return self._generate_token( + decode_slots, next_batch_id, tokens_or_logits, input_ids ) - logits = self.model(**model_inputs)[0] - return self._generate_token(decode_slots, next_batch_id, logits, input_ids) def _generate_token( self, slots: List[Slot], next_batch_id: int, - logits: torch.Tensor, + tokens_or_logits: torch.Tensor, input_ids: torch.LongTensor, ) -> Tuple[List[Generation], CachedBatch]: generations = [] @@ -569,9 +592,12 @@ class NeuronGenerator(Generator): if slot.state != Slot.State.READY: continue request_id = slot.request_id - next_token_logits = logits[i : i + 1, -1, :] slot_input_ids = input_ids[i : i + 1, :] - next_token = slot.select(slot_input_ids, next_token_logits) + if self.on_device_sampling: + next_token = tokens_or_logits[i] + else: + next_token_logits = tokens_or_logits[i : i + 1, -1, :] + next_token = slot.select(slot_input_ids, next_token_logits) next_token_text = slot.append(next_token) generated_text = None finish_reason = None @@ -622,7 +648,7 @@ class NeuronGenerator(Generator): def _cached_batch(self, batch_id: int, request_ids: List): size = len(request_ids) - max_tokens = size * self.model.max_length + max_tokens = size * self.model.neuron_config.sequence_length return CachedBatch( id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens ) @@ -671,8 +697,16 @@ class NeuronGenerator(Generator): Returns: A NeuronGenerator. """ - config = AutoConfig.from_pretrained(model_id) - neuron_config = getattr(config, "neuron", None) + try: + neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_id, + revision, + e, + ) + neuron_config = None start = time.time() if neuron_config is None: export_kwargs = get_export_kwargs_from_env() diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py index 2151a921..d281b175 100644 --- a/backends/neuron/server/text_generation_server/model.py +++ b/backends/neuron/server/text_generation_server/model.py @@ -6,10 +6,12 @@ from typing import Optional from huggingface_hub import snapshot_download from huggingface_hub.constants import HF_HUB_CACHE from loguru import logger -from transformers import AutoConfig -from optimum.neuron import NeuronModelForCausalLM -from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.cache import get_hub_cached_entries +from optimum.neuron.configuration_utils import NeuronConfig + + +from .tgi_env import check_env_and_neuron_config_compatibility def get_export_kwargs_from_env(): @@ -24,7 +26,6 @@ def get_export_kwargs_from_env(): num_cores = int(num_cores) auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None) return { - "task": "text-generation", "batch_size": batch_size, "sequence_length": sequence_length, "num_cores": num_cores, @@ -32,20 +33,15 @@ def get_export_kwargs_from_env(): } -def is_cached(model_id, neuron_config): +def is_cached(model_id): # Look for cached entries for the specified model in_cache = False - entries = get_hub_cached_entries(model_id, "inference") + entries = get_hub_cached_entries(model_id) # Look for compatible entries for entry in entries: - compatible = True - for key, value in neuron_config.items(): - # Only weights can be different - if key in ["checkpoint_id", "checkpoint_revision"]: - continue - if entry[key] != value: - compatible = False - if compatible: + if check_env_and_neuron_config_compatibility( + entry, check_compiler_version=True + ): in_cache = True break return in_cache @@ -87,8 +83,16 @@ def fetch_model( revision = None # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) # Note that the model may already be present in the cache. - config = AutoConfig.from_pretrained(model_id, revision=revision) - neuron_config = getattr(config, "neuron", None) + try: + neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_id, + revision, + e, + ) + neuron_config = None if neuron_config is not None: if os.path.isdir(model_id): return model_id @@ -99,16 +103,11 @@ def fetch_model( log_cache_size() return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") # Model needs to be exported: look for compatible cached entries on the hub - export_kwargs = get_export_kwargs_from_env() - export_config = NeuronModelForCausalLM.get_export_config( - model_id, config, revision=revision, **export_kwargs - ) - neuron_config = export_config.neuron - if not is_cached(model_id, neuron_config): + if not is_cached(model_id): hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache" neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi" error_msg = ( - f"No cached version found for {model_id} with {neuron_config}." + f"No cached version found for {model_id} with {get_export_kwargs_from_env()}." f"You can start a discussion to request it on {hub_cache_url}" f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}" ) @@ -121,8 +120,10 @@ def fetch_model( # Prefetch weights, tokenizer and generation config so that they are in cache log_cache_size() start = time.time() - snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") + snapshot_path = snapshot_download( + model_id, revision=revision, ignore_patterns="*.bin" + ) end = time.time() logger.info(f"Model weights fetched in {end - start:.2f} s.") log_cache_size() - return model_id + return snapshot_path diff --git a/backends/neuron/tgi_env.py b/backends/neuron/server/text_generation_server/tgi_env.py old mode 100755 new mode 100644 similarity index 63% rename from backends/neuron/tgi_env.py rename to backends/neuron/server/text_generation_server/tgi_env.py index a7042130..ee97f180 --- a/backends/neuron/tgi_env.py +++ b/backends/neuron/server/text_generation_server/tgi_env.py @@ -6,12 +6,11 @@ import os import sys from typing import Any, Dict, List, Optional -from huggingface_hub import constants -from transformers import AutoConfig - from optimum.neuron.modeling_decoder import get_available_cores -from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.cache import get_hub_cached_entries +from optimum.neuron.configuration_utils import NeuronConfig from optimum.neuron.utils.version_utils import get_neuronxcc_version +from optimum.neuron.utils import map_torch_dtype logger = logging.getLogger(__name__) @@ -24,15 +23,9 @@ tgi_router_env_vars = [ ] tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"] -env_config_peering = [ - ("MAX_BATCH_SIZE", "batch_size"), - ("MAX_TOTAL_TOKENS", "sequence_length"), - ("HF_AUTO_CAST_TYPE", "auto_cast_type"), - ("HF_NUM_CORES", "num_cores"), -] # By the end of this script all env var should be specified properly -env_vars = tgi_server_env_vars + tgi_router_env_vars +tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars available_cores = get_available_cores() neuronxcc_version = get_neuronxcc_version() @@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace: def neuron_config_to_env(neuron_config): + if isinstance(neuron_config, NeuronConfig): + neuron_config = neuron_config.to_dict() with open(os.environ["ENV_FILEPATH"], "w") as f: - for env_var, config_key in env_config_peering: - f.write("export {}={}\n".format(env_var, neuron_config[config_key])) + f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"])) + f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"])) + f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"])) + config_key = ( + "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype" + ) + auto_cast_type = neuron_config[config_key] + f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type)) max_input_tokens = os.getenv("MAX_INPUT_TOKENS") if not max_input_tokens: max_input_tokens = int(neuron_config["sequence_length"]) // 2 @@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config): def sort_neuron_configs(dictionary): - return -dictionary["num_cores"], -dictionary["batch_size"] + return -dictionary["tp_degree"], -dictionary["batch_size"] def lookup_compatible_cached_model( @@ -119,7 +120,7 @@ def lookup_compatible_cached_model( ) -> Optional[Dict[str, Any]]: # Reuse the same mechanic as the one in use to configure the tgi server part # The only difference here is that we stay as flexible as possible on the compatibility part - entries = get_hub_cached_entries(model_id, "inference") + entries = get_hub_cached_entries(model_id) logger.debug( "Found %d cached entries for model %s, revision %s", @@ -155,15 +156,15 @@ def lookup_compatible_cached_model( def check_env_and_neuron_config_compatibility( - neuron_config: Dict[str, Any], check_compiler_version: bool + neuron_config_dict: Dict[str, Any], check_compiler_version: bool ) -> bool: logger.debug( "Checking the provided neuron config %s is compatible with the local setup and provided environment", - neuron_config, + neuron_config_dict, ) # Local setup compat checks - if neuron_config["num_cores"] > available_cores: + if neuron_config_dict["tp_degree"] > available_cores: logger.debug( "Not enough neuron cores available to run the provided neuron config" ) @@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility( if ( check_compiler_version - and neuron_config["compiler_version"] != neuronxcc_version + and neuron_config_dict["neuronxcc_version"] != neuronxcc_version ): logger.debug( "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)", neuronxcc_version, - neuron_config["compiler_version"], + neuron_config_dict["neuronxcc_version"], ) return False - for env_var, config_key in env_config_peering: - neuron_config_value = str(neuron_config[config_key]) - env_value = os.getenv(env_var, str(neuron_config_value)) + batch_size = os.getenv("MAX_BATCH_SIZE", None) + if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size): + logger.debug( + "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)", + os.getenv("MAX_BATCH_SIZE"), + neuron_config_dict["batch_size"], + ) + return False + max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None) + if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int( + max_total_tokens + ): + logger.debug( + "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)", + max_total_tokens, + neuron_config_dict["sequence_length"], + ) + return False + num_cores = os.getenv("HF_NUM_CORES", None) + if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores): + logger.debug( + "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)", + num_cores, + neuron_config_dict["tp_degree"], + ) + return False + auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None) + if auto_cast_type is not None: + config_key = ( + "auto_cast_type" + if "auto_cast_type" in neuron_config_dict + else "torch_dtype" + ) + neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key])) + env_value = map_torch_dtype(auto_cast_type) if env_value != neuron_config_value: logger.debug( - "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)", - env_var, - config_key, + "The provided auto cast type and the neuron config param differ (%s != %s)", env_value, neuron_config_value, ) return False - max_input_tokens = int( os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)) ) if max_input_tokens > 0: - sequence_length = neuron_config["sequence_length"] + if hasattr(neuron_config_dict, "max_context_length"): + sequence_length = neuron_config_dict["max_context_length"] + else: + sequence_length = neuron_config_dict["sequence_length"] if max_input_tokens >= sequence_length: logger.debug( "Specified max input tokens is not compatible with config sequence length ( %s >= %s)", @@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility( def get_env_dict() -> Dict[str, str]: d = {} - for k in env_vars: + for k in tgi_env_vars: d[k] = os.getenv(k) return d -def main(): - """ - This script determines proper default TGI env variables for the neuron precompiled models to - work properly - :return: - """ - args = parse_cmdline_and_set_env() - - for env_var in env_vars: - if not os.getenv(env_var): - break - else: - logger.info( - "All env vars %s already set, skipping, user know what they are doing", - env_vars, +def get_neuron_config_for_model( + model_name_or_path: str, revision: Optional[str] = None +) -> NeuronConfig: + try: + neuron_config = NeuronConfig.from_pretrained( + model_name_or_path, revision=revision ) - sys.exit(0) - - cache_dir = constants.HF_HUB_CACHE - - logger.info("Cache dir %s, model %s", cache_dir, args.model_id) - - config = AutoConfig.from_pretrained(args.model_id, revision=args.revision) - neuron_config = getattr(config, "neuron", None) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_name_or_path, + revision, + e, + ) + neuron_config = None if neuron_config is not None: compatible = check_env_and_neuron_config_compatibility( - neuron_config, check_compiler_version=False + neuron_config.to_dict(), check_compiler_version=False ) if not compatible: env_dict = get_env_dict() @@ -252,17 +276,6 @@ def main(): logger.error(msg) raise Exception(msg) else: - neuron_config = lookup_compatible_cached_model(args.model_id, args.revision) + neuron_config = lookup_compatible_cached_model(model_name_or_path, revision) - if not neuron_config: - msg = ( - "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}" - ).format(get_env_dict(), available_cores, neuronxcc_version) - logger.error(msg) - raise Exception(msg) - - neuron_config_to_env(neuron_config) - - -if __name__ == "__main__": - main() + return neuron_config diff --git a/backends/neuron/tests/fixtures/model.py b/backends/neuron/tests/fixtures/model.py index 4b6a1375..ad41fd10 100644 --- a/backends/neuron/tests/fixtures/model.py +++ b/backends/neuron/tests/fixtures/model.py @@ -4,14 +4,12 @@ import subprocess import sys from tempfile import TemporaryDirectory -import huggingface_hub +import os import pytest from transformers import AutoTokenizer -from optimum.neuron import NeuronModelForCausalLM -from optimum.neuron.utils import synchronize_hub_cache -from optimum.neuron.version import __sdk_version__ as sdk_version -from optimum.neuron.version import __version__ as version + +from optimum.neuron.cache import synchronize_hub_cache logging.basicConfig( @@ -21,30 +19,14 @@ logging.basicConfig( ) logger = logging.getLogger(__file__) + OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache" + # All model configurations below will be added to the neuron_model_config fixture MODEL_CONFIGURATIONS = { - "gpt2": { - "model_id": "gpt2", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 1024, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, "llama": { - "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 2048, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, - "mistral": { - "model_id": "optimum/mistral-1.1b-testing", + "model_id": "unsloth/Llama-3.2-1B-Instruct", "export_kwargs": { "batch_size": 4, "sequence_length": 4096, @@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = { "batch_size": 4, "sequence_length": 4096, "num_cores": 2, - "auto_cast_type": "fp16", + "auto_cast_type": "bf16", }, }, "granite": { @@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = { } -def get_hub_neuron_model_id(config_name: str): - return ( - f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}" - ) - - def export_model(model_id, export_kwargs, neuron_model_path): export_command = [ "optimum-cli", @@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path): def neuron_model_config(request): """Expose a pre-trained neuron model - The fixture first makes sure the following model artifacts are present on the hub: - - exported neuron model under optimum-internal-testing/neuron-testing--, - - cached artifacts under optimum-internal-testing/neuron-testing-cache. - If not, it will export the model and push it to the hub. - - It then fetches the model locally and return a dictionary containing: + The fixture exports a model locally and returns a dictionary containing: - a configuration name, - the original model id, - the export parameters, - - the neuron model id, - the neuron model local path. For each exposed model, the local directory is maintained for the duration of the test session and cleaned up afterwards. - The hub model artifacts are never cleaned up and persist accross sessions. - They must be cleaned up manually when the optimum-neuron version changes. """ config_name = request.param model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param]) model_id = model_config["model_id"] export_kwargs = model_config["export_kwargs"] - neuron_model_id = get_hub_neuron_model_id(config_name) with TemporaryDirectory() as neuron_model_path: - hub = huggingface_hub.HfApi() - if hub.repo_exists(neuron_model_id): - logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub") - hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path) - else: - export_model(model_id, export_kwargs, neuron_model_path) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.save_pretrained(neuron_model_path) - del tokenizer - # Create the test model on the hub - hub.create_repo(neuron_model_id, private=True) - hub.upload_folder( - folder_path=neuron_model_path, - repo_id=neuron_model_id, - ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"], - ) - # Make sure it is cached - synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID) + export_model(model_id, export_kwargs, neuron_model_path) + synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(neuron_model_path) + del tokenizer # Add dynamic parameters to the model configuration model_config["neuron_model_path"] = neuron_model_path - model_config["neuron_model_id"] = neuron_model_id # Also add model configuration name to allow tests to adapt their expectations model_config["name"] = config_name # Yield instead of returning to keep a reference to the temporary directory. # It will go out of scope and be released only once all tests needing the fixture # have been completed. logger.info(f"{config_name} ready for testing ...") + os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID yield model_config logger.info(f"Done with {config_name}") diff --git a/backends/neuron/tests/server/test_cached_model.py b/backends/neuron/tests/server/test_cached_model.py new file mode 100644 index 00000000..73622578 --- /dev/null +++ b/backends/neuron/tests/server/test_cached_model.py @@ -0,0 +1,42 @@ +import os +import pytest + +from text_generation_server.generator import NeuronGenerator +from text_generation_server.model import fetch_model, is_cached + + +@pytest.fixture(scope="module") +def cached_model_id(neuron_model_config) -> str: + """ + Fixture to provide a cached model ID for testing. + This assumes the model is already cached in the local environment. + """ + export_kwargs = neuron_model_config["export_kwargs"] + os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"]) + os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"]) + os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"] + os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"]) + yield neuron_model_config["model_id"] + os.environ.pop("MAX_BATCH_SIZE", None) + os.environ.pop("MAX_TOTAL_TOKENS", None) + os.environ.pop("HF_AUTO_CAST_TYPE", None) + os.environ.pop("HF_NUM_CORES", None) + + +def test_model_is_cached(cached_model_id): + assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached" + + +def test_fetch_cached_model(cached_model_id: str): + model_path = fetch_model(cached_model_id) + assert os.path.exists( + model_path + ), f"Model {cached_model_id} was not fetched successfully" + assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory" + + +def test_generator_from_cached_model(cached_model_id: str): + generator = NeuronGenerator.from_pretrained(model_id=cached_model_id) + assert generator is not None, "Generator could not be created from cached model" + assert generator.model is not None, "Generator model is not initialized" + assert generator.tokenizer is not None, "Generator tokenizer is not initialized" diff --git a/backends/neuron/tests/server/test_continuous_batching.py b/backends/neuron/tests/server/test_continuous_batching.py index 48bb70cc..3d9ab509 100644 --- a/backends/neuron/tests/server/test_continuous_batching.py +++ b/backends/neuron/tests/server/test_continuous_batching.py @@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config): """ neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) - assert generator.model.batch_size > 1 + assert generator.model.neuron_config.batch_size > 1 input_text = "Once upon a time" max_new_tokens = 20 # Prefill a single request, remembering the generated token tokens = {0: [], 1: []} request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens) - max_length = generator.model.max_length + max_length = generator.model.neuron_config.sequence_length batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length) generations, next_batch = generator.prefill(batch) assert next_batch.size == 1 diff --git a/backends/neuron/tests/server/test_decode.py b/backends/neuron/tests/server/test_decode.py index 9db5e3ab..b864e3ec 100644 --- a/backends/neuron/tests/server/test_decode.py +++ b/backends/neuron/tests/server/test_decode.py @@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample): request = create_request( id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample ) - max_length = generator.model.max_length + max_length = generator.model.neuron_config.sequence_length batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length) generations, next_batch = generator.prefill(batch) # We already generated one token: call decode max_new_tokens - 1 times @@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample): assert output.finish_reason == 0 if do_sample: expected_text = { - "gpt2": " The sun was set", - "llama": "George Orwell, 1984", - "mistral": "The sky was", - "qwen2": " A young woman with", + "llama": " I sat alone in the café", + "qwen2": " The air was so still", "granite": "1984, George Orwell", }[config_name] assert expected_text in output.text else: print(output.text) expected_text = { - "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going', - "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story", - "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.", + "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility", "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a", "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198", }[config_name] diff --git a/backends/neuron/tests/server/test_prefill.py b/backends/neuron/tests/server/test_prefill.py index c0155b1a..c9ecd1c8 100644 --- a/backends/neuron/tests/server/test_prefill.py +++ b/backends/neuron/tests/server/test_prefill.py @@ -9,7 +9,7 @@ def test_prefill(neuron_model_config): neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) max_batch_size = 4 - assert generator.model.batch_size >= max_batch_size + assert generator.model.neuron_config.batch_size >= max_batch_size for num_requests in [1, max_batch_size]: for do_sample in [True, False]: mode = "sample" if do_sample else "greedy" @@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample): ) ) # Let's be pessimistic when estimating max_tokens - max_length = generator.model.max_length + max_length = generator.max_prefill_length() batch = Batch( id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length ) @@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample): assert len(generations) == batch_size if do_sample: expectations = { - "gpt2": [383, " The"], - "llama": [10058, " George"], - "mistral": [450, " The"], - "qwen2": [362, " A"], + "llama": [358, " I"], + "qwen2": [576, " The"], "granite": [308, " ("], }[config_name] else: expectations = { - "gpt2": [198, "\n"], - "llama": [10058, " George"], - "mistral": [13, "\n"], + "llama": [578, " The"], "qwen2": [358, " I"], "granite": [203, "\n"], }[config_name] @@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config): config_name = neuron_model_config["name"] neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) - batch_size = generator.model.batch_size + batch_size = generator.model.neuron_config.batch_size # We apply truncation to all requests but the first one truncate = [ None, @@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config): requests = [] for i in range(batch_size): requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i])) - max_length = generator.model.max_length + max_length = generator.max_prefill_length() batch = Batch( id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length ) @@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config): # Even if the input text is identical for all requests, the first generated token might # be different because of the truncation expectations = { - "gpt2": [" He", " He", "\n", " He"], - "llama": [" —", " The", " He", " He"], - "mistral": [" He", "\n", " He", " He"], + "llama": [" He", "iens", "\x08", " He"], "qwen2": [" He", " The", " He", " He"], "granite": ["\n", "\n", " I", " He"], }[config_name] for i, g in enumerate(generations): tokens = g.tokens - assert tokens.texts[0] == expectations[i] + assert ( + tokens.texts[0] == expectations[i] + ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]" diff --git a/backends/neuron/tests/test_entry_point.py b/backends/neuron/tests/test_entry_point.py new file mode 100644 index 00000000..d4ddf338 --- /dev/null +++ b/backends/neuron/tests/test_entry_point.py @@ -0,0 +1,63 @@ +import os +import pytest +from tempfile import TemporaryDirectory + +from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig +from optimum.neuron.utils import map_torch_dtype + +from text_generation_server.tgi_env import ( + get_neuron_config_for_model, + lookup_compatible_cached_model, + neuron_config_to_env, +) + + +def test_get_neuron_config_for_model(neuron_model_config): + neuron_model_path = neuron_model_config["neuron_model_path"] + export_kwargs = neuron_model_config["export_kwargs"] + os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"]) + os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"]) + os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"] + os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"]) + neuron_config = get_neuron_config_for_model(neuron_model_path) + assert neuron_config is not None + assert neuron_config.batch_size == export_kwargs["batch_size"] + assert neuron_config.sequence_length == export_kwargs["sequence_length"] + assert neuron_config.tp_degree == export_kwargs["num_cores"] + if isinstance(neuron_config, NxDNeuronConfig): + assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype( + export_kwargs["auto_cast_type"] + ) + else: + assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype( + export_kwargs["auto_cast_type"] + ) + + +@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"]) +def test_lookup_compatible_cached_model(model_id: str): + neuron_config = lookup_compatible_cached_model(model_id, None) + assert neuron_config is not None + + +def test_neuron_config_to_env(neuron_model_config) -> None: + neuron_model_path = neuron_model_config["neuron_model_path"] + neuron_config = get_neuron_config_for_model(neuron_model_path) + with TemporaryDirectory() as temp_dir: + os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh") + neuron_config_to_env(neuron_config) + with open(os.environ["ENV_FILEPATH"], "r") as env_file: + env_content = env_file.read() + assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content + assert ( + f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}" + in env_content + ) + assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content + if hasattr(neuron_config, "torch_dtype"): + auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split( + "." + )[-1] + else: + auto_cast_type = neuron_config.auto_cast_type + assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh index b959a795..7965d1da 100755 --- a/backends/neuron/tgi-entrypoint.sh +++ b/backends/neuron/tgi-entrypoint.sh @@ -9,7 +9,7 @@ touch $ENV_FILEPATH SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -${SCRIPT_DIR}/tgi_env.py $@ +${SCRIPT_DIR}/tgi_entry_point.py $@ source $ENV_FILEPATH diff --git a/backends/neuron/tgi_entry_point.py b/backends/neuron/tgi_entry_point.py new file mode 100755 index 00000000..7e81d0e4 --- /dev/null +++ b/backends/neuron/tgi_entry_point.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import logging +import os +import sys + + +from text_generation_server.tgi_env import ( + available_cores, + get_env_dict, + get_neuron_config_for_model, + neuron_config_to_env, + neuronxcc_version, + parse_cmdline_and_set_env, + tgi_env_vars, +) + + +logger = logging.getLogger(__name__) + + +def main(): + """ + This script determines proper default TGI env variables for the neuron precompiled models to + work properly + :return: + """ + args = parse_cmdline_and_set_env() + + for env_var in tgi_env_vars: + if not os.getenv(env_var): + break + else: + logger.info( + "All env vars %s already set, skipping, user know what they are doing", + tgi_env_vars, + ) + sys.exit(0) + + neuron_config = get_neuron_config_for_model(args.model_id, args.revision) + + if not neuron_config: + msg = ( + "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}" + ).format(get_env_dict(), available_cores, neuronxcc_version) + logger.error(msg) + raise Exception(msg) + + neuron_config_to_env(neuron_config) + + +if __name__ == "__main__": + main() diff --git a/integration-tests/fixtures/neuron/export_models.py b/integration-tests/fixtures/neuron/export_models.py index 836402ec..d4d0f01c 100644 --- a/integration-tests/fixtures/neuron/export_models.py +++ b/integration-tests/fixtures/neuron/export_models.py @@ -28,15 +28,6 @@ logger = logging.getLogger(__file__) # All model configurations below will be added to the neuron_model_config fixture MODEL_CONFIGURATIONS = { - "gpt2": { - "model_id": "gpt2", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 1024, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, "llama": { "model_id": "unsloth/Llama-3.2-1B-Instruct", "export_kwargs": { @@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = { "auto_cast_type": "fp16", }, }, - "mistral": { - "model_id": "optimum/mistral-1.1b-testing", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 4096, - "num_cores": 2, - "auto_cast_type": "bf16", - }, - }, "qwen2": { "model_id": "Qwen/Qwen2.5-0.5B", "export_kwargs": { diff --git a/integration-tests/neuron/test_generate.py b/integration-tests/neuron/test_generate.py index f0804356..9108ce0e 100644 --- a/integration-tests/neuron/test_generate.py +++ b/integration-tests/neuron/test_generate.py @@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service): ) assert response.details.generated_tokens == 17 greedy_expectations = { - "gpt2": "\n\nDeep learning is a new field of research that has been around for a while", - "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial", - "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", + "llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial", "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art", } @@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load): assert len(responses) == 4 expectations = { - "gpt2": "Deep learning is a new field of research that has been around for a while", "llama": "Deep learning is a subset of machine learning that uses artificial", - "mistral": "Deep Learning is a type of machine learning that", "qwen2": "Deep Learning is a subset of Machine Learning that is based on", "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art", } From 839477670aed6498c74b785585ad321ef5f7b3c7 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Wed, 11 Jun 2025 21:00:21 +0800 Subject: [PATCH 6/9] [gaudi] Perf optimization (#3256) Signed-off-by: Wang, Yi A --- .../layers/attention/__init__.py | 2 + .../layers/attention/hpu.py | 24 ++- .../custom_modeling/flash_cohere_modeling.py | 5 + .../custom_modeling/flash_dbrx_modeling.py | 5 + .../flash_deepseek_v2_modeling.py | 5 + .../flash_deepseek_v3_modeling.py | 5 + .../custom_modeling/flash_gemma2_modeling.py | 5 + .../custom_modeling/flash_gemma_modeling.py | 5 + .../custom_modeling/flash_gpt2_modeling.py | 5 + .../custom_modeling/flash_gptj_modeling.py | 5 + .../custom_modeling/flash_llama4_modeling.py | 5 + .../custom_modeling/flash_llama_modeling.py | 6 + .../custom_modeling/flash_mistral_modeling.py | 5 + .../custom_modeling/flash_mixtral_modeling.py | 6 +- .../custom_modeling/flash_neox_modeling.py | 5 + .../custom_modeling/flash_phi_modeling.py | 5 + .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../custom_modeling/flash_qwen3_modeling.py | 7 +- .../custom_modeling/flash_rw_modeling.py | 5 + .../flash_santacoder_modeling.py | 5 + .../flash_starcoder2_modeling.py | 6 +- .../models/flash_causal_lm.py | 160 ++++++++++++------ .../models/flash_vlm_causal_lm.py | 4 +- .../models/mllama_causal_lm.py | 4 +- 24 files changed, 229 insertions(+), 66 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 370e05bc..aa639832 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -11,6 +11,7 @@ from .hpu import ( attention, paged_attention, paged_attention_mla, + set_block_mapping, ) @@ -22,6 +23,7 @@ __all__ = [ "get_kv_scales", "paged_attention", "paged_attention_mla", + "set_block_mapping", "SUPPORTS_WINDOWING", "KVCache", "KVCompressCache", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 8cca7a29..f12005d2 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA from vllm_hpu_extension.utils import ModuleFusedSDPA import os from text_generation_server.models.globals import BLOCK_SIZE +import math SUPPORTS_WINDOWING = False @@ -106,6 +107,21 @@ def attention( return attn_output +def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size): + block_mapping = torch.nn.functional.one_hot( + hpu_attention_meta.block_groups, num_classes=batch_size + ) + dtype = hpu_attention_meta.block_usage.dtype + device = hpu_attention_meta.block_usage.device + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + hpu_attention_meta = hpu_attention_meta._replace( + attn_bias=attn_bias, block_mapping=block_mapping.to(dtype) + ) + return hpu_attention_meta + + def paged_attention( query: torch.Tensor, kv_cache: KVCache, @@ -176,4 +192,10 @@ def paged_attention_mla( return output.view(batch_size, head_num, -1) -__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"] +__all__ = [ + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", + "paged_attention_mla", + "set_block_mapping", +] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 801ae09e..7a32a85c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module): seqlen: torch.Tensor, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 76972d38..42af7798 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 6ac7fc1a..8e9002a2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index e0481691..8e058093 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -34,6 +34,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention_mla, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a5860823..a1a20999 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module): adapter_data: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 3d678df1..7a2ec22e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module): adapter_data: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index ed413662..a6b53656 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -27,6 +27,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds residual = None diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index cde03a00..679380a1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 0e3af85a..c6b68f33 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -43,6 +43,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.attention import ( KVCache, paged_attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -548,6 +549,10 @@ class Llama4TextModel(nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds bs = seqlen.input_lengths.shape[0] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index dfb16621..70fcc824 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -549,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) + hidden_states = inputs_embeds # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 75d9d360..a4ad8f59 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index f47986d8..4993b444 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales @@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model( input_ids, position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 29620826..6e1050b6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,6 +29,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 12830991..78aaf0d5 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,6 +9,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 7c7ac03e..ac31e53b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,6 +8,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index 66a17877..8bd00c13 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -266,7 +267,10 @@ class Qwen3Model(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: - + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 76a2cd01..06616f85 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, paged_attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c64b2ff7..b6a0d32a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,6 +8,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.wte(input_ids) + self.wpe(position_ids) if self.process_group.size() > 1: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 94c60eb6..1a749595 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,6 +29,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module): adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model( input_ids, position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index f8abe5ad..13a2a307 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -153,19 +153,14 @@ def prepare_for_decode( block_list_device = _async_h2d_tensor_copy(block_list) block_groups_device = _async_h2d_tensor_copy(block_groups) block_usage_device = _async_h2d_tensor_copy(block_usage) - block_mapping = torch.nn.functional.one_hot( - block_groups_device, num_classes=batch_size - ) - mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage_device.unsqueeze(-1) - attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + return trim_attn_metadata( HPUPagedAttentionMetadata( block_list=block_list_device, block_groups=block_groups_device, block_usage=block_usage_device, - block_mapping=block_mapping.to(dtype), - attn_bias=attn_bias, + block_mapping=None, + attn_bias=None, ) ) @@ -428,10 +423,8 @@ class FlashCausalLMBatch(Batch): for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids - # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) + # put on cpu temporarily, move to hpu in prepare_for_prefill + all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) @@ -701,7 +694,9 @@ class FlashCausalLMBatch(Batch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + def concatenate( + cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0 + ) -> "FlashCausalLMBatch": # Batch attributes requests = [] requests_idx_mapping = {} @@ -750,7 +745,10 @@ class FlashCausalLMBatch(Batch): adapter_meta = None adapter_segment_builder = None else: - input_ids = batches[0].input_ids.new_empty(total_batch_size) + if padded_total_bs == batches[0].input_ids.shape[0]: + input_ids = batches[0].input_ids + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) if ( batches[0].position_ids is not None and batches[0].position_ids.dim() == 2 @@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( - (total_batch_size, max_length) - ) + all_input_ids_tensor = batches[0].all_input_ids_tensor top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) @@ -829,9 +825,12 @@ class FlashCausalLMBatch(Batch): index = torch.tensor(list(range(start_index, end_index)), device="cpu") top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) - all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:valid_bsize, :max_length] + if i > 0: + all_input_ids_tensor.index_copy_( + 0, + index.to(batch.all_input_ids_tensor.device), + batch.all_input_ids_tensor[:valid_bsize, :], + ) block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] @@ -851,9 +850,10 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids.index_copy_( - 0, index.to(input_ids.device), batch.input_ids[:valid_bsize] - ) + if padded_total_bs != batches[0].input_ids.shape[0] or i > 0: + input_ids.index_copy_( + 0, index.to(input_ids.device), batch.input_ids[:valid_bsize] + ) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) slot_indices.index_copy_( 0, index, batch.slot_indices + cumulative_slots @@ -987,7 +987,6 @@ class FlashCausalLMBatch(Batch): else: padded_bs = self.input_ids.shape[0] slots = self.slots[self.slot_indices] - extra_pad = padded_bs - self.input_ids.shape[0] self.hpu_attn_meta = prepare_for_decode( dtype, @@ -998,17 +997,20 @@ class FlashCausalLMBatch(Batch): padded_bs, bucketing_ctx, ) - self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0) - self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1) + self.input_ids = F.pad( + self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0 + ) + self.position_ids = F.pad( + self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 + ) self.input_lengths_tensor = F.pad( - self.input_lengths_tensor, (0, extra_pad), value=0 + self.input_lengths_tensor, + (0, padded_bs - self.input_lengths_tensor.shape[0]), + value=0, ) self.cache_lengths_tensor = F.pad( - self.cache_lengths_tensor, (0, extra_pad), value=0 - ) - self.all_input_ids_tensor = F.pad( - self.all_input_ids_tensor, - (0, 0, 0, extra_pad), + self.cache_lengths_tensor, + (0, padded_bs - self.cache_lengths_tensor.shape[0]), value=0, ) next_token_chooser_parameters = [] @@ -1028,7 +1030,9 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states, ) - def prepare_for_prefill(self, max_padded_input_len, max_padded_bs): + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything @@ -1044,7 +1048,7 @@ class FlashCausalLMBatch(Batch): # need extra pad to match warmup seq extra_pad = max_padded_input_len - self.max_input_length extra_pad_bs = max_padded_bs - len(self) - device = self.all_input_ids_tensor.device + device = "hpu" if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1288,12 +1292,17 @@ class FlashCausalLMBatch(Batch): self.prefill_next_token_indices = ( self.prefill_next_token_indices + input_ids_padded_length_tensor ) - - self.all_input_ids_tensor = F.pad( - self.all_input_ids_tensor, - (0, 0, 0, extra_pad_bs), - value=0, + all_input_ids_tensor = torch.zeros( + (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])), + dtype=torch.int64, + device="hpu", ) + for i in range(len(self)): + all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = ( + self.all_input_ids_tensor[i] + ) + self.all_input_ids_tensor = all_input_ids_tensor + next_token_chooser_parameters = [] next_token_chooser_parameters.extend([r.parameters for r in self.requests]) pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs) @@ -1459,6 +1468,8 @@ class FlashCausalLM(Model): self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype self.bucketing_ctx = None + self.max_total_tokens = None + self.max_input_tokens = None htorch.core.hpu_set_env() if htorch.utils.internal.is_lazy(): htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) @@ -1564,6 +1575,14 @@ class FlashCausalLM(Model): logger.info, f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}", ) + if max_total_tokens is None: + max_total_tokens = sum(batch.input_lengths) + + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + + self.max_total_tokens = max_total_tokens + self.max_input_tokens = max_input_tokens try: self.init_kv_cache( batch.num_blocks, @@ -1597,11 +1616,6 @@ class FlashCausalLM(Model): ) log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") - if max_total_tokens is None: - max_total_tokens = sum(batch.input_lengths) - - if max_input_tokens is None: - max_input_tokens = max_total_tokens - 1 self.kv_cache = [] empty_cache() @@ -2017,7 +2031,9 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], + batch.all_input_ids_tensor[ + : batch.next_token_logits.shape[0], : batch.max_current_length + ], batch.next_token_logits, speculate, batch.speculative_ids, @@ -2031,14 +2047,29 @@ class FlashCausalLM(Model): accepted_ids, ) if batch.valid_indices is not None: - next_token_logprobs = next_token_logprobs.cpu() - accepted_ids = accepted_ids.cpu() - batch.all_input_ids_tensor = batch.all_input_ids_tensor[ - batch.valid_indices - ] - next_input_ids = next_input_ids[batch.valid_indices] - next_token_logprobs = next_token_logprobs[batch.valid_indices] - accepted_ids = accepted_ids[batch.valid_indices] + # TODO speculative decoding handling missing + index = torch.arange( + 0, + len(batch.valid_indices), + device=batch.all_input_ids_tensor.device, + ) + batch.all_input_ids_tensor.index_copy_( + 0, index, batch.all_input_ids_tensor[batch.valid_indices] + ) + padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size( + len(batch.valid_indices) + ) + next_input_ids.index_copy_( + 0, index, next_input_ids[batch.valid_indices] + ) + next_input_ids = next_input_ids[:padded_total_bs] + + next_token_logprobs.index_copy_( + 0, index, next_token_logprobs[batch.valid_indices] + ) + accepted_ids.index_copy_( + 0, index, accepted_ids[batch.valid_indices] + ) if speculative_ids is not None: speculative_ids = speculative_ids[batch.valid_indices] batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[ @@ -2106,10 +2137,13 @@ class FlashCausalLM(Model): batch.slot_indices += accepted_ids[: len(batch)] else: index = batch.cache_lengths_tensor + batch.input_lengths_tensor + index = F.pad( + index, (0, next_input_ids.shape[0] - index.shape[0]), value=0 + ) index = index.to(batch.all_input_ids_tensor.device) batch_idx = torch.arange( 0, - batch.all_input_ids_tensor.shape[0], + index.shape[0], dtype=torch.long, device=batch.all_input_ids_tensor.device, ) @@ -2197,7 +2231,18 @@ class FlashCausalLM(Model): htorch.core.mark_step() # Stage 2. Prepare new batch for speculative scheduling if len(batches) > 1: - batch = self.batch_type.concatenate(batches) + if self.bucketing_ctx is not None: + total_batch_size = 0 + for b in batches: + total_batch_size += len(b) + padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size( + total_batch_size + ) + batch = self.batch_type.concatenate( + batches, padded_total_bs=padded_total_bs + ) + else: + batch = self.batch_type.concatenate(batches) else: batch = batches[0] prefill = batch.prefilling @@ -2208,9 +2253,12 @@ class FlashCausalLM(Model): batch.max_input_length ), self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)), + self.max_total_tokens, ) else: - batch.prepare_for_prefill(batch.max_input_length, len(batch)) + batch.prepare_for_prefill( + batch.max_input_length, len(batch), self.max_total_tokens + ) else: batch.prepare_for_decode( self.dtype, self.use_contiguous_pa, self.bucketing_ctx diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index e604fd3c..9755ee6d 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -262,8 +262,8 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches) + def concatenate(cls, batches, padded_total_bs: int = 0): + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 771cc0a8..13939974 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -48,8 +48,8 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super().concatenate(batches) + def concatenate(cls, batches, padded_total_bs: int = 0): + batch = super().concatenate(batches, padded_total_bs) batch.pixel_values = None batch.pixel_attention_mask = None From 613b8dd647043c1c23c3587b85e22619caa2a7f3 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 13 Jun 2025 04:26:37 +0800 Subject: [PATCH 7/9] [gaudi] Vlm rebase and issue fix in benchmark test (#3263) Signed-off-by: Wang, Yi A --- .../text_generation_server/models/__init__.py | 11 +- .../custom_modeling/flash_llama4_modeling.py | 93 ++- .../custom_modeling/flash_llava_next.py | 201 +++--- .../models/custom_modeling/flash_mllama.py | 25 +- .../flash_pali_gemma_modeling.py | 53 +- .../models/custom_modeling/idefics2.py | 191 +++--- .../models/custom_modeling/idefics3.py | 191 +++--- .../models/custom_modeling/qwen2_5_vl.py | 136 ++-- .../models/custom_modeling/qwen2_vl.py | 129 ++-- .../models/flash_causal_lm.py | 34 +- .../models/flash_vlm_causal_lm.py | 615 ++++++++++++++---- .../models/mllama_causal_lm.py | 16 +- .../models/pali_gemma.py | 71 -- 13 files changed, 1092 insertions(+), 674 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/models/pali_gemma.py diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index c46c79fb..18396e8d 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -83,9 +83,6 @@ try: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, ) - from text_generation_server.models.pali_gemma import ( - PaliGemmaBatch, - ) from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( PaliGemmaForConditionalGeneration, ) @@ -153,7 +150,6 @@ if FLASH_ATTENTION: ) VLM_BATCH_TYPES = { - PaliGemmaBatch, FlashVlmCausalLMBatch, FlashMllamaCausalLMBatch, } @@ -635,6 +631,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) elif model_type == BAICHUAN: return FlashCausalLM( @@ -784,6 +781,8 @@ def get_model( kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + # TODO: Fix bug in rust image_text_replacement implementation + support_chunking=False, ) elif model_type == QWEN2_5_VL: return FlashVlmCausalLM( @@ -799,6 +798,8 @@ def get_model( lora_adapter_ids=lora_adapter_ids, config_class=Qwen2_5_VLConfig, processor_class=Qwen2_5_VLProcessor, + # TODO: Fix bug in rust image_text_replacement implementation + support_chunking=False, ) elif model_type == QWEN3: return FlashCausalLM( @@ -824,6 +825,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) elif model_type == IDEFICS2: return FlashVlmCausalLM( @@ -868,7 +870,6 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, ) elif model_type == LLAVA_NEXT: return FlashVlmCausalLM( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index c6b68f33..3b30f2e0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -1356,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module): hidden_state = self.vision_model(pixel_values) return hidden_state - def forward( + def get_vision_embeds( self, - input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.config.vision_config.vision_feature_layer, + vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy, + image_sizes=image_sizes, + ) + vision_flat = image_features.view(-1, image_features.size(-1)) + image_features = self.multi_modal_projector(vision_flat) + return image_features + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, pixel_values: torch.FloatTensor = None, - pixel_attention_mask=None, - position_ids: Optional[torch.LongTensor] = None, - cu_seqlen_prefill: Optional[torch.Tensor] = None, - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None, - slots: torch.Tensor = None, - seqlen: Seqlen = None, - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, List[int]]] = None, - vision_feature_select_strategy: Optional[str] = None, - image_sizes: torch.Tensor = None, - lm_head_indices: Optional[torch.Tensor] = None, - adapter_data: Optional[torch.Tensor] = None, - **lm_kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - - def _get_padding_mask(input_ids, pad_token_id=0): - return (input_ids != pad_token_id).long() - - attention_mask = _get_padding_mask(input_ids) - attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1) + image_sizes: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.text_model.model.embed_tokens(input_ids) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_config.vision_feature_select_strategy - ) - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist original_inputs_embeds_shape = inputs_embeds.shape - - vision_flat = image_features.view(-1, image_features.size(-1)) - projected_vision_flat = self.multi_modal_projector(vision_flat) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( -1 ) @@ -1414,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module): final_mask_1d = final_mask[..., 0].reshape(-1) num_tokens_to_fill = final_mask_1d.sum() - if num_tokens_to_fill != projected_vision_flat.size(0): + if num_tokens_to_fill != vision_embeds.size(0): raise ValueError( f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " - f"but multi_modal_projector returned {projected_vision_flat.size(0)}" + f"but multi_modal_projector returned {vision_embeds.size(0)}" ) expanded_mask = final_mask_1d.unsqueeze(-1).expand( -1, inputs_embeds.size(-1) ) - inputs_embeds = inputs_embeds.masked_scatter( - expanded_mask, projected_vision_flat - ) + inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds) inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlen_prefill: Optional[torch.Tensor] = None, + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None, + slots: torch.Tensor = None, + seqlen: Seqlen = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + lm_head_indices: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + **lm_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: logits, speculative_logits = self.text_model( inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py index 88548042..c4d4f728 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -163,9 +163,114 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): ) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." + ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." + ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + return image_features.view(-1, image_features.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -173,101 +278,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - # Unused for this model - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None and len(pixel_values) > 0: - # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() - # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" - # 1. Extract the input embeddings - - # 2. Merge text and images - num_images, num_patches, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view( - num_images * num_patches, channels, height, width - ) - image_features = self.vision_tower(pixel_values) - - # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] - # Already done within the clip model - selected_image_feature = image_features.last_hidden_state - - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise RuntimeError( - f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." - ) - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [num_patches] * num_images - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_features - ) hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index 421a0a65..fe6d137b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import ( ) from habana_frameworks.torch.hpex.kernels import FusedSDPA from vllm_hpu_extension.utils import ModuleFusedSDPA +import habana_frameworks.torch as htorch def _prepare_aspect_ratio_attention_mask( @@ -236,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module): key = key.transpose(1, 2) value = value.transpose(1, 2) - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, ) - attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_seq_len, -1) @@ -320,6 +330,9 @@ class MllamaVisionEncoder(nn.Module): attention_mask: Optional[torch.Tensor] = None, ): encoder_states = [hidden_states] + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, @@ -328,6 +341,8 @@ class MllamaVisionEncoder(nn.Module): hidden_states = layer_outputs encoder_states.append(hidden_states) + if lazy_mode: + htorch.core.mark_step() return hidden_states, encoder_states @@ -699,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module): # key_states = key_states.repeat(1, self.num_key_value_groups, 1) # value_states = value_states.repeat(1, self.num_key_value_groups, 1) - - causal = False # logger.info( # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" # ) @@ -715,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module): value_states, attn_mask=None, dropout_p=0.0, - is_causal=causal, + is_causal=False, scale=None, softmax_mode="None", recompute_mode=None, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 4d31d5dd..a13b9f09 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module): self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 ) + self.dtype = weights.dtype + + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + image_outputs = self.vision_tower(pixel_values) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) + image_features = image_features.view(-1, image_features.shape[-1]) + return image_features + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + mask = input_ids == self.config.image_token_index + inputs_embeds[mask] = vision_embeds + + return inputs_embeds def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -73,32 +103,13 @@ class PaliGemmaForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - # Unused here - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. if cu_seqlen_prefill is not None: position_ids += 1 - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) - image_outputs = self.vision_tower(pixel_values) - last_hidden_state = self.post_vision_tower_layernorm( - image_outputs.last_hidden_state - ) - image_features = self.multi_modal_projector(last_hidden_state) - - # mask where image or padding tokens - mask = input_ids == self.config.image_token_index - - # insert image features into input embeddings - inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index 02806ac9..41a45373 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -734,9 +734,107 @@ class Idefics2ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + assert pixel_values is not None + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + """ + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.gt(patches_subgrid, 0) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -744,98 +842,9 @@ class Idefics2ForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - """ - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - """ - # hpu does none support unfold - conv_kernel = torch.ones( - [1, 1, patch_size, patch_size], - dtype=pixel_values.dtype, - device=pixel_values.device, - ) - patches_subgrid = torch.nn.functional.conv2d( - pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), - conv_kernel, - stride=patch_size, - ).squeeze(1) - patch_attention_mask = torch.eq( - patches_subgrid, (patch_size * patch_size) - ) - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), - ) - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - # When we generate, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py index 964526fc..6dd44c11 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -477,9 +477,107 @@ class Idefics3ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + + """ + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.gt(patches_subgrid, 0) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -487,99 +585,10 @@ class Idefics3ForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - """ - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - """ - # hpu does none support unfold - conv_kernel = torch.ones( - [1, 1, patch_size, patch_size], - dtype=pixel_values.dtype, - device=pixel_values.device, - ) - patches_subgrid = torch.nn.functional.conv2d( - pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), - conv_kernel, - stride=patch_size, - ).squeeze(1) - patch_attention_mask = torch.eq( - patches_subgrid, (patch_size * patch_size) - ) - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - ) - - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 441b0016..a80a86a7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -45,6 +45,11 @@ from text_generation_server.layers.attention import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, ) +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) +import habana_frameworks.torch as htorch # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py from typing import Union @@ -375,28 +380,6 @@ class Qwen2_5_VLConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - class Qwen2_5VLAttention(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() @@ -426,7 +409,8 @@ class Qwen2_5VLAttention(nn.Module): self, hidden_state: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: # apply the qkv linear layer to the hidden state @@ -444,29 +428,37 @@ class Qwen2_5VLAttention(nn.Module): query = query.view(*_shape) key = key.view(*_shape) value = value.view(*_shape) - # apply rotary positional embeddings - query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( - 0 - ) - key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + rotary_dim = cos.shape[-1] + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape)) - # calc maximum sequence length for any batch - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - causal = False + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape)) # execute sdpa - query = query.unsqueeze(0).transpose(1, 2) - key = key.unsqueeze(0).transpose(1, 2) - value = value.unsqueeze(0).transpose(1, 2) + causal = False + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attention_mask = torch.zeros( + [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i] + ] = True attn_output = fsdpa_op( query, key, value, - attn_mask=None, + attn_mask=attention_mask, dropout_p=0.0, is_causal=causal, scale=None, @@ -474,7 +466,7 @@ class Qwen2_5VLAttention(nn.Module): recompute_mode=None, valid_sequence_lengths=None, ) - attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = attn_output.transpose(0, 1) # reshape output to original dimensions attn_output = attn_output.reshape(hidden_state.shape[0], -1) @@ -533,11 +525,9 @@ class Qwen2_5VLVisionBlock(nn.Module): weights=weights, ) - def forward( - self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen - ) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor: norm1_out, _ = self.norm1(hidden_states) - attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen) hidden_states = hidden_states + attn_out norm2_out, _ = self.norm2(hidden_states) mlp_out = self.mlp(norm2_out) @@ -608,7 +598,7 @@ class Qwen2_5VisionModel(nn.Module): config=config, weights=weights, ) - # import ipdb; ipdb.set_trace() + self.temporal_patch_size = config.temporal_patch_size self.spatial_patch_size = config.spatial_patch_size self.in_channels = config.in_channels @@ -736,6 +726,10 @@ class Qwen2_5VisionModel(nn.Module): ) rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = torch.cat((cos, cos), dim=-1).unsqueeze(1) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(1) cu_window_seqlens = torch.tensor( cu_window_seqlens, @@ -754,6 +748,9 @@ class Qwen2_5VisionModel(nn.Module): max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) # iterately apply the blocks to the hidden states + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for layer_num, block in enumerate(self.blocks): # NOTE: qwen2_5_vl.py has a concept of full attention blocks # that are applied at specific layers. @@ -762,9 +759,9 @@ class Qwen2_5VisionModel(nn.Module): else: cu_seqlens_now = cu_window_seqlens - hidden_states = block( - hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen - ) + hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen) + if lazy_mode: + htorch.core.mark_step() # apply the final patch merger to the hidden states hidden_states = self.merger(hidden_states) @@ -886,9 +883,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): full_llm_pos_ids_list = [ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist ] - # import ipdb - - # ipdb.set_trace() max_s = full_llm_pos_ids_list[-1].max() + 1 final_text_len = input_ids_len - vision_ends[-1] if final_text_len > 0: @@ -900,9 +894,33 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): ) return position_ids - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None: + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = vision_embeds + + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -910,26 +928,10 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor], - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - # Unused in this model - video_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - mask = torch.where(input_ids == self.image_token_id) - inputs_embeds[mask] = image_embeds hidden_states = self.text_model( inputs_embeds=inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 47ae2ac9..96acef31 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -44,28 +44,11 @@ from text_generation_server.layers.attention import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) +import habana_frameworks.torch as htorch class Qwen2VLAttention(nn.Module): @@ -96,7 +79,8 @@ class Qwen2VLAttention(nn.Module): self, hidden_state: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: # apply the qkv linear layer to the hidden state @@ -116,27 +100,36 @@ class Qwen2VLAttention(nn.Module): value = value.view(*_shape) # apply rotary positional embeddings - query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( - 0 - ) - key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + rotary_dim = cos.shape[-1] + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape)) - # calc maximum sequence length for any batch - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - causal = False + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape)) # execute sdpa - query = query.unsqueeze(0).transpose(1, 2) - key = key.unsqueeze(0).transpose(1, 2) - value = value.unsqueeze(0).transpose(1, 2) + causal = False + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attention_mask = torch.zeros( + [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i] + ] = True attn_output = fsdpa_op( query, key, value, - attn_mask=None, + attn_mask=attention_mask, dropout_p=0.0, is_causal=causal, scale=None, @@ -144,7 +137,7 @@ class Qwen2VLAttention(nn.Module): recompute_mode=None, valid_sequence_lengths=None, ) - attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = attn_output.transpose(0, 1) # reshape output to original dimensions attn_output = attn_output.reshape(hidden_state.shape[0], -1) attn_output = self.proj(attn_output) @@ -193,11 +186,9 @@ class Qwen2VLVisionBlock(nn.Module): weights=weights, ) - def forward( - self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen - ) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor: norm1_out, residual = self.norm1(hidden_states) - attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen) hidden_states = attn_out + residual norm2_out, residual = self.norm2(hidden_states) hidden_states = hidden_states + self.mlp(norm2_out) @@ -330,6 +321,11 @@ class Qwen2VisionModel(nn.Module): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = torch.cat((cos, cos), dim=-1).unsqueeze(1) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(1) + # create a cu_seqlens tensor to be used in the attention mask cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -337,8 +333,13 @@ class Qwen2VisionModel(nn.Module): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) # iterately apply the blocks to the hidden states + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for block in self.blocks: - hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen) + if lazy_mode: + htorch.core.mark_step() # apply the final patch merger to the hidden states hidden_states = self.merger(hidden_states) @@ -474,9 +475,33 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) return position_ids - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None: + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = vision_embeds + + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -484,26 +509,10 @@ class Qwen2VLForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor], - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - mask = torch.where(input_ids == self.image_token_id) - inputs_embeds[mask] = image_embeds - hidden_states = self.text_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 13a2a307..cb8c742e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1000,9 +1000,18 @@ class FlashCausalLMBatch(Batch): self.input_ids = F.pad( self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0 ) - self.position_ids = F.pad( - self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 - ) + + if self.position_ids.dim() == 2: + # Qwen VL case + self.position_ids = F.pad( + self.position_ids, + (0, 0, 0, padded_bs - self.position_ids.shape[0]), + value=1, + ) + else: + self.position_ids = F.pad( + self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 + ) self.input_lengths_tensor = F.pad( self.input_lengths_tensor, (0, padded_bs - self.input_lengths_tensor.shape[0]), @@ -1066,8 +1075,19 @@ class FlashCausalLMBatch(Batch): input_ids = [0] * extra_pad + input_ids self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) else: - self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) - input_ids_padded_length.extend([extra_pad] * len(self)) + input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self)) + src_pos = 0 + for i in range(len(self)): + end_pos = (i + 1) * max_padded_input_len + start_pos = end_pos - self.input_lengths[i] + input_ids[start_pos:end_pos] = self.input_ids[ + src_pos : src_pos + self.input_lengths[i] + ] + input_ids_padded_length.append( + max_padded_input_len - self.input_lengths[i] + ) + src_pos += self.input_lengths[i] + self.input_ids = input_ids self.input_ids = F.pad( self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0 @@ -1457,7 +1477,7 @@ class FlashCausalLM(Model): if head_size is None: # Some models use GQA and different sizes for o_proj # and q_proj, that allows for that. - if hasattr(config, "head_dim"): + if getattr(config, "head_dim", None) is not None: self.head_size = config.head_dim else: self.head_size = config.hidden_size // config.num_attention_heads @@ -2263,6 +2283,8 @@ class FlashCausalLM(Model): batch.prepare_for_decode( self.dtype, self.use_contiguous_pa, self.bucketing_ctx ) + if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds): + self.set_inputs_embeds(batch) prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 9755ee6d..086c05e7 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -1,7 +1,7 @@ import torch from PIL import Image from io import BytesIO - +from dataclasses import dataclass from opentelemetry import trace from typing import Iterable, Optional, Tuple, List, Type, Dict @@ -119,17 +119,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def image_text_replacement(processor, image_input, config, image_id: int) -> str: +def image_text_replacement(processor, image_input, config) -> str: if config.model_type == "idefics2": image_seq_len = 64 image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" if processor.image_processor.do_image_splitting: image_str *= 5 - return image_str + return image_str, IDEFICS2_FAKE_TOKEN if config.model_type == "idefics3": # TODO: implement this in a more general way - n_rows = image_input["rows"][0][image_id] - n_cols = image_input["cols"][0][image_id] + n_rows = image_input["rows"][0][0] + n_cols = image_input["cols"][0][0] image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) @@ -142,41 +142,41 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str image_token=IDEFICS3_IMAGE_TOKEN, global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) - return image_str + return image_str, IDEFICS3_FAKE_IMAGE_TOKEN elif config.model_type == "llava_next": - height, width = image_input["image_sizes"][image_id] + height, width = image_input["image_sizes"][0] num_features = get_number_of_features(height, width, config) log_master( logger.info, f"Found {num_features} features in image of resolution {height}x{width}", ) - return "" * num_features + return "" * num_features, "" elif config.model_type == "paligemma": - return "" * config.text_config.num_image_tokens + return "" * config.text_config.num_image_tokens, "" elif config.model_type == "qwen2_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "qwen2_5_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "gemma3": # TODO: get correct number of features via reviewing the Gemma3 architecture # and calculating the number of image tokens num_pads = 256 padding = "" * num_pads - return f"\n\n{padding}\n\n" + return f"\n\n{padding}\n\n", "" elif config.model_type == "llama4": patch_size = config.vision_config.patch_size pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2))) - aspect_ratios = image_input["aspect_ratios"][image_id] - image_height, image_width = image_input["pixel_values"][image_id].shape[-2:] + aspect_ratios = image_input["aspect_ratios"][0] + image_height, image_width = image_input["pixel_values"][0].shape[-2:] num_patches_per_chunk = int( (image_height // patch_size) @@ -187,7 +187,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str aspect_ratios, num_patches_per_chunk ) - return tokens_for_this_image + return tokens_for_this_image, "<|image_start|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -200,6 +200,27 @@ def image_text_replacement_fixup(config, text: str) -> str: return text +def preprocess_text(config, text: str) -> str: + if config.model_type == "paligemma": + return "" + text + "\n" + return text + + +def preprocess_image(config, img): + model_type = config.model_type + + if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20: + img = img.resize((img.width * 2, img.height * 2)) + if model_type == "paligemma": + img = img.convert("RGB") + + if model_type not in {"llava_next", "gemma3", "llama4"}: + # TODO: check if this is needed + img = [img] + + return img + + def get_unpadded_features( original_height: int, original_width: int, @@ -254,105 +275,259 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features +def scatter_image_embeds( + embeds: torch.Tensor, is_embed: Optional[torch.Tensor] +) -> torch.Tensor: + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed.to(embeds.device)] = embeds + return placeholders + + +def gather_image_embeds( + embeds: torch.Tensor, is_embed: Optional[torch.Tensor] +) -> Optional[torch.Tensor]: + if is_embed is None: + return embeds + sel = embeds[is_embed.to(embeds.device)] + return sel if sel.numel() else None + + +@dataclass +class ImagePositions: + offset: int + length: int + id: int + num_placeholder_tokens: int + is_embed: Optional[torch.Tensor] = None + + class FlashVlmCausalLMBatch(FlashCausalLMBatch): + image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]] + image_positions: Optional[List[List[ImagePositions]]] + encoder_cache: Optional[List[Dict[int, torch.Tensor]]] pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] image_grid_thw: Optional[torch.Tensor] + cache_entries_to_free: List[Tuple[int, int]] + has_image_inputs: bool = False + inputs_embeds: Optional[torch.Tensor] = None @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches, padded_total_bs: int = 0): batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs) + batch.image_inputs = [] + batch.image_positions = [] + batch.encoder_cache = [] + for b in batches: + if b.image_inputs is not None: + batch.image_inputs.extend(b.image_inputs) + else: + batch.image_inputs.append(None) + if b.image_positions is not None: + batch.image_positions.extend(b.image_positions) + else: + batch.image_positions.append(None) + if b.encoder_cache is not None: + batch.encoder_cache.extend(b.encoder_cache) + else: + batch.encoder_cache.append(None) + batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None + batch.inputs_embeds = None + # To be filled in prepare_for_prefill + batch.has_image_inputs = False + batch.cache_entries_to_free = [] return batch @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + + image_inputs = [] + image_positions = [] + encoder_cache = [] + + for request_id in request_ids: + idx = self.requests_idx_mapping[request_id] + image_inputs.append(self.image_inputs[idx]) + image_positions.append(self.image_positions[idx]) + encoder_cache.append(self.encoder_cache[idx]) + batch = super().filter(request_ids) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None + batch.inputs_embeds = None + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = encoder_cache + + # To be filled in prepare_for_prefill + batch.has_image_inputs = False + batch.cache_entries_to_free = [] return batch @classmethod def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config ): - # Process images first. We need all of them so that the processor - # can make the image splits the same size. And we need the final - # sizes to insert correct number of image tokens. - images = [] + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + max_length = 0 + vocab = tokenizer.get_vocab() + + if not hasattr(config, "image_token_index"): + config.image_token_index = config.image_token_id + + batch_tokenized_inputs: List[List[int]] = [] + batch_image_inputs: List[Optional[List[dict]]] = [] + batch_image_positions: List[Optional[List[ImagePositions]]] = [] + for r in requests: + text_parts = [] + image_inputs = [] + image_texts = [] + + image_id = 0 + for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": - pass + text = preprocess_text(config, chunk.text) + text_parts.append(text) elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the - # default warmup image is 20x20 - if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: - if image.width <= 20: - w = image.width * 2 - h = image.height * 2 - image = image.resize((w, h)) + img = Image.open(BytesIO(chunk.image.data)) + img = preprocess_image(config, img) - if config.model_type == "llava_next": - images.append(image) - elif config.model_type == "gemma3": - images.append(image) - elif config.model_type == "llama4": - images.append(image) - else: - images.append([image]) + image_input = processor.image_processor( + [img], return_tensors="pt", **kwargs + ) + image_inputs.append(image_input) + + img_text, img_start_token_str = image_text_replacement( + processor, image_input, config + ) + text_parts.append(img_text) + + image_texts.append([image_id, img_start_token_str, img_text]) + image_id += 1 else: raise RuntimeError(f"Invalid chunk type {chunk_type}") - if images: - kwargs = {} - if ( - hasattr(processor, "image_processor_class") - and processor.image_processor_class == "Idefics3ImageProcessor" - ): - kwargs["return_row_col_info"] = True - - image_inputs = processor.image_processor( - images, return_tensors="pt", **kwargs - ) - else: - image_inputs = None - - batch_tokenized_inputs = [] - max_length = 0 - image_id = 0 - for r in requests: - full_text = "" - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += chunk.text - elif chunk_type == "image": - full_text += image_text_replacement( - processor, image_inputs, config, image_id - ) - image_id += 1 - - full_text = image_text_replacement_fixup(config, full_text) + full_text = image_text_replacement_fixup(config, "".join(text_parts)) input_ids = tokenizer( full_text, truncation=True, max_length=r.truncate, - add_special_tokens=r.add_special_tokens, + add_special_tokens=( + r.add_special_tokens if config.model_type != "paligemma" else False + ), )["input_ids"] max_length = max(max_length, len(input_ids)) - batch_tokenized_inputs.append(input_ids) - return batch_tokenized_inputs, image_inputs + if len(image_inputs) > 0: + img_start_token = vocab[image_texts[0][1]] + image_positions = cls.get_image_positions( + input_ids, image_texts, img_start_token, config, tokenizer + ) + else: + image_inputs = None + image_positions = None + + batch_tokenized_inputs.append(input_ids) + batch_image_inputs.append(image_inputs) + batch_image_positions.append(image_positions) + + return batch_tokenized_inputs, batch_image_inputs, batch_image_positions + + @classmethod + def get_image_positions( + cls, + input_ids: List[int], + image_texts: List[Tuple[int, str, str]], + img_start_token: int, + config, + tokenizer: PreTrainedTokenizerBase, + ) -> List[ImagePositions]: + image_positions = [] + num_images = len(image_texts) + + input_ids_t = torch.as_tensor(input_ids) + img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0] + num_tokens = input_ids_t.numel() + + last_pos = 0 + for i in range(num_images): + image_id, img_start_token_str, img_text = image_texts[i] + img_text = image_text_replacement_fixup(config, img_text) + + if config.model_type == "gemma3": + img_text = img_text.replace("\n\n", "") + + tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ][0] + length = tokens.numel() + + assert ( + length <= num_tokens + ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens" + + pos = torch.searchsorted(img_start_token_pos, last_pos, right=False) + index = img_start_token_pos[pos] + assert torch.equal( + input_ids_t[index : index + length], tokens + ), "Image tokens not found in input_ids" + + is_embed = tokens == config.image_token_index + num_placeholder_tokens = int(is_embed.sum()) + if num_placeholder_tokens == length: + is_embed = None + + pos = ImagePositions( + offset=index, + length=length, + id=image_id, + num_placeholder_tokens=num_placeholder_tokens, + is_embed=is_embed, + ) + + image_positions.append(pos) + last_pos = index + length + + if ( + config.model_type == "idefics2" + and i + 1 != num_images + and input_ids[last_pos] == config.image_token_index + ): + fake_token = last_pos - 1 + fake_token_index = torch.searchsorted( + img_start_token_pos, fake_token, right=False + ) + img_start_token_pos[fake_token_index] = last_pos + image_texts[i + 1][2] = image_texts[i + 1][2][ + len(img_start_token_str) : + ] + + return image_positions @classmethod def from_pb_processor( @@ -364,33 +539,164 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "FlashVlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config + batch_tokenized_inputs, image_inputs, image_positions = ( + cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config) ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "image_grid_thw" in image_inputs: - batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) - else: - batch.image_grid_thw = None - else: + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = [{} for _ in range(len(pb.requests))] + if len(image_inputs): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None return batch + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): + super().prepare_for_prefill( + max_padded_input_len, max_padded_bs, max_total_tokens + ) + + self.has_image_inputs = False + self.cache_entries_to_free = [] + + self.pixel_values = [] + + assert ( + len(self.cache_lengths) + == len(self.input_lengths) + == len(self.prefilling_mask) + ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask" + + for i, ( + cache_length, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.cache_lengths, + self.input_lengths, + self.prefilling_mask, + ) + ): + if not request_prefilling or self.image_positions[i] is None: + continue + + for image_position in self.image_positions[i]: + if image_position is None: + continue + start_pos = image_position.offset + length = image_position.length + + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + self.has_image_inputs = True + + if image_position.id not in self.encoder_cache[i]: + image_inputs = self.image_inputs[i][image_position.id] + self.pixel_values.append((i, image_position.id, image_inputs)) + + # Remove the image from the image_inputs + self.image_inputs[i][image_position.id] = None + + if not self.has_image_inputs: + self.pixel_values = None + self.pixel_attention_mask = None + self.image_sizes = None + self.image_grid_thw = None + else: + image_grid_thw_list = [ + x[2]["image_grid_thw"] + for x in self.pixel_values + if "image_grid_thw" in x[2] + ] + if image_grid_thw_list: + self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0) + else: + self.image_grid_thw = None + + def update_encoder_cache(self, encoder_outputs, request_id, img_pos): + self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds( + encoder_outputs, img_pos.is_embed + ) + + def gather_vision_embeds(self): + device = self.input_ids.device + chunks = [] + for ( + i, + cache_length, + input_length, + request_prefilling, + ) in zip( + range(len(self.requests)), + self.cache_lengths, + self.input_lengths, + self.prefilling_mask, + ): + if not request_prefilling or self.image_positions[i] is None: + continue + + for image_position in self.image_positions[i]: + if image_position is None: + continue + start_pos = image_position.offset + length = image_position.length + + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + start_idx = max(cache_length - start_pos, 0) + end_idx = min(cache_length - start_pos + input_length, length) + + assert ( + image_position.id in self.encoder_cache[i] + ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}" + encoder_output = self.encoder_cache[i][image_position.id] + + is_embed = image_position.is_embed + if is_embed is not None: + is_embed = is_embed[start_idx:end_idx] + + from loguru import logger + + logger.info( + f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}" + ) + + embeds = gather_image_embeds( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + if embeds is not None: + chunks.append(embeds) + + if end_idx == length: + self.cache_entries_to_free.append((i, image_position.id)) + self.image_positions[i][image_position.id] = None + + if len(chunks) == 0: + return None + return torch.cat(chunks, dim=0).to(device) + + def free_encoder_cache(self): + for i, image_id in self.cache_entries_to_free: + self.encoder_cache[i].pop(image_id, None) + + self.cache_entries_to_free = [] + class FlashVlmCausalLM(FlashCausalLM): def __init__( @@ -402,6 +708,7 @@ class FlashVlmCausalLM(FlashCausalLM): batch_class=FlashVlmCausalLMBatch, revision, trust_remote_code: bool, + support_chunking: bool = False, **kwargs, ): if PREFIX_CACHING: @@ -419,8 +726,7 @@ class FlashVlmCausalLM(FlashCausalLM): model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, - # FIXME: VLM do not work with context chunking yet - support_chunking=False, + support_chunking=support_chunking, **kwargs, ) @@ -471,9 +777,12 @@ class FlashVlmCausalLM(FlashCausalLM): bucketing_ctx=None, ) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) + inputs_embeds = self.get_inputs_embeds( + input_ids=input_ids.to(self.device), + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + inputs_embeds=inputs_embeds, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=None, kv_cache=self.kv_cache, @@ -481,10 +790,7 @@ class FlashVlmCausalLM(FlashCausalLM): seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, - pixel_values=None, - pixel_attention_mask=None, - image_sizes=None, - image_grid_thw=None, + attention_mask=None, ) def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): @@ -546,6 +852,84 @@ class FlashVlmCausalLM(FlashCausalLM): f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", ) + def get_vision_embeds( + self, + pixel_values: torch.Tensor, + pixel_attention_mask: torch.Tensor, + image_sizes: torch.Tensor, + image_grid_thw: torch.Tensor, + ): + embeds = self.model.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + return embeds + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: Optional[torch.Tensor] = None, + ): + return self.model.get_inputs_embeds( + input_ids=input_ids, + vision_embeds=vision_embeds, + ) + + def encode_images(self, batch): + if batch.pixel_values is not None: + device = batch.input_ids.device + for request_id, image_id, image_input in batch.pixel_values: + pixel_values = image_input["pixel_values"].to(device) + + if "pixel_attention_mask" in image_input: + pixel_attention_mask = image_input["pixel_attention_mask"].to( + device + ) + else: + pixel_attention_mask = None + + if "image_sizes" in image_input: + image_sizes = image_input["image_sizes"].to(device) + else: + image_sizes = None + + if "image_grid_thw" in image_input: + image_grid_thw = image_input["image_grid_thw"] + else: + image_grid_thw = None + + encoder_outputs = self.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + batch.update_encoder_cache( + encoder_outputs, + request_id, + batch.image_positions[request_id][image_id], + ) + + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + + def set_inputs_embeds(self, batch): + if batch.has_image_inputs: + self.encode_images(batch) + vision_embeds = batch.gather_vision_embeds() + batch.has_image_inputs = False + else: + vision_embeds = None + + inputs_embeds = self.get_inputs_embeds( + batch.input_ids, vision_embeds=vision_embeds + ) + + batch.inputs_embeds = inputs_embeds + def forward( self, batch: FlashVlmCausalLMBatch, @@ -593,6 +977,7 @@ class FlashVlmCausalLM(FlashCausalLM): position_ids = new_position_ids else: input_ids = batch.input_ids + inputs_embeds = batch.inputs_embeds position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache @@ -605,10 +990,25 @@ class FlashVlmCausalLM(FlashCausalLM): if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: if position_ids.dim() == 1 and batch.prefilling: position_ids = self.model.get_position_ids( - input_ids, batch.image_grid_thw + input_ids.cpu(), batch.image_grid_thw ) batch.position_ids = position_ids + attention_mask = None + attention_mask_forward = None + if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None: + attention_mask = self.model.get_attention_mask( + input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True + ) + min_dtype = torch.finfo(self.dtype).min + attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to( + input_ids.device + ) + attention_mask = attention_mask.reshape(-1) + if self.model.config.model_type == "llama4": + attention_mask = (input_ids != 0).long() + attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1) + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. @@ -639,7 +1039,7 @@ class FlashVlmCausalLM(FlashCausalLM): input_lengths=_async_h2d_tensor_copy(input_lengths), ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, + inputs_embeds=inputs_embeds, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, @@ -647,18 +1047,11 @@ class FlashVlmCausalLM(FlashCausalLM): seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - image_grid_thw=batch.image_grid_thw, + attention_mask=attention_mask_forward, **kwargs, ) - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - if batch.image_grid_thw is not None: - batch.image_grid_thw = None + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + batch.image_grid_thw = None + batch.free_encoder_cache() return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 13939974..1be36d09 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -46,10 +46,17 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): aspect_ratio_mask: Optional[torch.Tensor] = None cross_attention_states: Optional[torch.Tensor] = None + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): + super(FlashVlmCausalLMBatch, self).prepare_for_prefill( + max_padded_input_len, max_padded_bs, max_total_tokens + ) + @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches, padded_total_bs: int = 0): - batch = super().concatenate(batches, padded_total_bs) + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs) batch.pixel_values = None batch.pixel_attention_mask = None @@ -73,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): assert self.image_indices is not None - batch = super().filter(request_ids) + batch = super(FlashVlmCausalLMBatch, self).filter(request_ids) assert self.image_indices is not None indices = [] for i, request_id in enumerate(request_ids): @@ -99,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): ] else: batch.cross_attention_states = None + batch.pixel_values = None return batch @classmethod @@ -228,6 +236,10 @@ def generate_cross_attention_states( class FlashMllamaCausalLM(FlashVlmCausalLM): + def set_inputs_embeds(self, batch): + # Set the input embeddings to None, as we are using the input_ids for the model + batch.inputs_embeds = None + def warmup_decode( self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch ): diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py deleted file mode 100644 index e91aaed9..00000000 --- a/backends/gaudi/server/text_generation_server/models/pali_gemma.py +++ /dev/null @@ -1,71 +0,0 @@ -from io import BytesIO -from PIL import Image -import torch -import torch.distributed -from opentelemetry import trace -from typing import Iterable -from text_generation_server.models.flash_vlm_causal_lm import ( - FlashVlmCausalLMBatch, - image_text_replacement, -) - -from text_generation_server.pb.generate_pb2 import Request - -tracer = trace.get_tracer(__name__) - - -class PaliGemmaBatch(FlashVlmCausalLMBatch): - @classmethod - def batch_tokenized_inputs( - cls, requests: Iterable[Request], tokenizer, processor, config - ): - batch_inputs = [] - image_inputs = [] - max_truncation = 0 - for r in requests: - full_text = "" - image_id = 0 - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += "" + chunk.text + "\n" - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # TODO do_convert_RGB should be on by default ? - image = image.convert("RGB") - image_input = processor.image_processor(image, return_tensors="pt") - full_text += image_text_replacement( - processor, image_input, config, image_id - ) - image_inputs.append(image_input) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=False, - )["input_ids"] - if image_inputs: - image_input = image_inputs[0] - new_image_inputs = { - "pixel_values": torch.cat( - [img["pixel_values"] for img in image_inputs], dim=0 - ), - } - if "pixel_attention_mask" in image_input: - new_image_inputs["pixel_attention_mask"] = torch.cat( - [img["pixel_attention_mask"] for img in image_inputs], dim=0 - ) - if "image_sizes" in image_input: - new_image_inputs["image_sizes"] = torch.cat( - [img["image_sizes"] for img in image_inputs], dim=0 - ) - image_inputs = new_image_inputs - else: - image_inputs = None - return batch_tokenized_inputs, image_inputs From 25fdc5f03c2e1666cbb4ed9f4315d4cb0a6107ef Mon Sep 17 00:00:00 2001 From: Yuan Wu Date: Fri, 13 Jun 2025 04:31:11 +0800 Subject: [PATCH 8/9] [gaudi] Move the _update_cos_sin_cache into get_cos_sin (#3254) Signed-off-by: yuanwu --- .../text_generation_server/layers/rotary.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 7e740e5f..d381d4c6 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, inv_freq.device, max_position_embeddings - ) + self.max_position_embeddings = max_position_embeddings def forward( self, @@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin(self, position_ids: torch.Tensor): - + self._update_cos_sin_cache( + torch.float32, position_ids.device, seqlen=self.max_position_embeddings + ) cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, short_inv_freq.device, max_position_embeddings - ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, short_inv_freq.device, max_position_embeddings - ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -598,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): position_ids: torch.Tensor, ): slen = position_ids.shape[0] + self._update_cos_sin_cache( + torch.float32, position_ids.device, seqlen=self.max_position_embeddings + ) cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) From e07056ab3f0a8a6e748bcaf766508385fcd4a7fa Mon Sep 17 00:00:00 2001 From: Yuan Wu Date: Fri, 13 Jun 2025 04:35:36 +0800 Subject: [PATCH 9/9] [Gaudi] Remove optimum-habana (#3261) Signed-off-by: yuanwu --- Dockerfile_gaudi | 2 +- backends/gaudi/server/pyproject.toml | 5 +- backends/gaudi/server/requirements.txt | 5 +- .../server/text_generation_server/cli.py | 89 +- .../habana_quantization_env.py | 53 - .../text_generation_server/models/__init__.py | 73 +- .../text_generation_server/models/bloom.py | 52 - .../models/causal_lm.py | 1444 --------------- .../models/custom_modeling/llava_next.py | 467 ----- .../models/custom_modeling/mllama.py | 292 --- .../models/custom_modeling/qwen2_5_vl.py | 3 +- .../models/galactica.py | 156 -- .../text_generation_server/models/globals.py | 4 +- .../models/idefics_causal_lm.py | 882 --------- .../text_generation_server/models/mamba.py | 814 --------- .../models/starcoder.py | 47 - .../models/vlm_causal_lm.py | 1609 ----------------- backends/gaudi/tgi-entrypoint.sh | 8 - launcher/src/env_runtime.rs | 4 - launcher/src/main.rs | 9 - 20 files changed, 23 insertions(+), 5995 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/habana_quantization_env.py delete mode 100644 backends/gaudi/server/text_generation_server/models/bloom.py delete mode 100644 backends/gaudi/server/text_generation_server/models/causal_lm.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py delete mode 100644 backends/gaudi/server/text_generation_server/models/galactica.py delete mode 100644 backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py delete mode 100644 backends/gaudi/server/text_generation_server/models/mamba.py delete mode 100644 backends/gaudi/server/text_generation_server/models/starcoder.py delete mode 100644 backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 442eb6b7..02885405 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -57,7 +57,7 @@ ARG PYTORCH_VERSION FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base -ENV ATTENTION=default +ENV ATTENTION=paged ENV PREFIX_CACHING=0 ENV PREFILL_CHUNKING=0 ENV PT_HPU_LAZY_MODE=1 diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml index 3f2676cb..fa2c2697 100644 --- a/backends/gaudi/server/pyproject.toml +++ b/backends/gaudi/server/pyproject.toml @@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0" hf-transfer = "^0.1.9" sentencepiece = "^0.2.0" peft = "^0.15" -optimum-habana = "1.17" -transformers = "^4.49" +transformers = "^4.52.4" numpy = "^1.26" -accelerate = "^0.33" +accelerate = "^1.7.0" outlines= { version = "^0.0.36", optional = true } prometheus-client = "^0.21.1" py-cpuinfo = "^9.0.0" diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt index 6f897722..e6c9abf2 100644 --- a/backends/gaudi/server/requirements.txt +++ b/backends/gaudi/server/requirements.txt @@ -1,4 +1,4 @@ -accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" +accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13" attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13" certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13" @@ -46,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13" -optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13" optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13" outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13" packaging==24.2 ; python_version >= "3.9" and python_version < "3.13" @@ -76,7 +75,7 @@ sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13" triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index d4445a13..dc31ab2f 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -1,6 +1,4 @@ import os -import psutil -import signal import sys import typer @@ -115,80 +113,19 @@ def serve( raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) - - logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - - if sharded and os.getenv("ATTENTION", "default") not in {"paged"}: - tgi_file = Path(__file__).resolve().parent / "tgi_service.py" - num_shard = int(os.getenv("WORLD_SIZE", "1")) - logger.info("CLI SHARDED = {}".format(num_shard)) - import subprocess - - cmd = ( - f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" - ) - cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" - cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" - cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}" - if speculate is not None: - cmd += f"--speculate {speculate}" - logger.info("CLI server start deepspeed ={} ".format(cmd)) - sys.stdout.flush() - sys.stderr.flush() - with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: - do_terminate = False - current_handler = signal.getsignal(signal.SIGTERM) - - def terminate_handler(sig, frame): - nonlocal do_terminate - do_terminate = True - if callable(current_handler): - current_handler(sig, frame) - - signal.signal(signal.SIGTERM, terminate_handler) - - finished = False - while not finished: - try: - if do_terminate: - parent = psutil.Process(proc.pid) - all_procs = parent.children(recursive=True) + [parent] - for p in all_procs: - try: - p.terminate() - except psutil.NoSuchProcess: - pass - _, alive = psutil.wait_procs(all_procs, timeout=30) - for p in alive: - p.kill() - - do_terminate = False - - proc.wait(timeout=3) - except subprocess.TimeoutExpired: - pass - else: - finished = True - - sys.stdout.flush() - sys.stderr.flush() - if proc.returncode != 0: - logger.error(f"{cmd} exited with status = {proc.returncode}") - return proc.returncode - else: - server.serve( - model_id, - lora_adapters, - revision, - sharded, - quantize, - speculate, - dtype, - kv_cache_dtype, - trust_remote_code, - uds_path, - max_input_tokens, - ) + server.serve( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + dtype, + kv_cache_dtype, + trust_remote_code, + uds_path, + max_input_tokens, + ) @app.command() diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py deleted file mode 100644 index b03b7e26..00000000 --- a/backends/gaudi/server/text_generation_server/habana_quantization_env.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import os -import habana_frameworks.torch as htorch - -quant_config = os.getenv("QUANT_CONFIG", "") -is_quantization_enabled = quant_config != "" - -if is_quantization_enabled: - os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true") - os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true") - os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false") - os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false") - os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av") - os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") - - -def patch_scoped_linear_all_reduce(model): - from deepspeed.module_inject.layers import LinearAllreduce - from optimum.habana.transformers.models.modeling_all_models import ( - ScopedLinearAllReduce, - ) - - for name, module in model.named_children(): - if type(module) is LinearAllreduce: - SL = ScopedLinearAllReduce(mod=module) - setattr(model, name, SL) - patch_scoped_linear_all_reduce(module) - - -def setup_quantization(model): - if is_quantization_enabled: - htorch.core.quantization._mark_params_as_const(model) - htorch.core.quantization._check_params_as_const(model) - htorch.core.hpu_initialize(model) - return model - - -def prepare_model_for_quantization(model): - if is_quantization_enabled: - if model.config.model_type in [ - "llama", - "falcon", - "qwen2", - "starcoder2", - "gemma", - ]: - patch_scoped_linear_all_reduce(model) - from neural_compressor.torch.quantization import FP8Config, convert - - config = FP8Config.from_json_file(quant_config) - model = convert(model, config) - return model diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 18396e8d..c4943463 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -5,7 +5,6 @@ import os from loguru import logger from transformers.configuration_utils import PretrainedConfig -from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path @@ -36,14 +35,10 @@ __all__ = [ "Seq2SeqLM", "get_model_with_lora_adapters", ] -from text_generation_server.models.globals import ATTENTION VLM_BATCH_TYPES = set() -FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." -FLASH_ATTENTION = False -if ATTENTION == "paged": - FLASH_ATTENTION = True +FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM @@ -883,72 +878,6 @@ def get_model( trust_remote_code=trust_remote_code, ) - from text_generation_server.models.causal_lm import CausalLM - from text_generation_server.models.vlm_causal_lm import VlmCausalLM - from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, - ) - from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, - ) - from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, - ) - - VLM_BATCH_TYPES.add(VlmCausalLMBatch) - - from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi - - adapt_transformers_to_gaudi() - if SDP_ON_BF16 == 1: - torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) - if model_type == "gpt_bigcode": - from text_generation_server.models.starcoder import StarCoder - - return StarCoder(model_id=model_id, revision=revision, dtype=dtype) - if model_type == "bloom": - from text_generation_server.models.bloom import BLOOM - - return BLOOM( - model_id=model_id, - revision=revision, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == "llava_next": - return VlmCausalLM( - model_class=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == "mllama": - return VlmCausalLM( - model_class=MllamaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - raise ValueError(f"Unsupported model type {model_type}") diff --git a/backends/gaudi/server/text_generation_server/models/bloom.py b/backends/gaudi/server/text_generation_server/models/bloom.py deleted file mode 100644 index 6fe64374..00000000 --- a/backends/gaudi/server/text_generation_server/models/bloom.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import torch - -from typing import Optional, Type - -from transformers import PreTrainedTokenizerBase - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 - - -class BloomCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb( - pb=pb, - tokenizer=tokenizer, - dtype=dtype, - device=device, - ) - batch.keys_head_dim_last = False - return batch - - -class BLOOM(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(BLOOM, self).__init__( - model_id=model_id, - revision=revision, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py deleted file mode 100644 index dd6e070d..00000000 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ /dev/null @@ -1,1444 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import bisect -from dataclasses import dataclass -from functools import wraps -import itertools -import json -import math -import os -import tempfile -import time -import copy -from typing import Dict, List, Optional, Tuple, Type - -import torch -import torch._dynamo -from loguru import logger -from opentelemetry import trace - -import text_generation_server.habana_quantization_env as hq_env -from text_generation_server.utils import weight_files -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from text_generation_server.utils.chunks import concat_text_chunks -from optimum.habana.checkpoint_utils import model_on_meta -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - PreTrainedTokenizerBase, - AutoConfig, -) - -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - StoppingCriteria, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -from optimum.habana.utils import get_hpu_memory_stats -from text_generation_server.utils.debug import dbg_trace -from text_generation_server.utils.speculate import get_speculate - -tracer = trace.get_tracer(__name__) -MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2)) -SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2)) -MAX_BATCH_SIZE = ( - int(os.environ.get("MAX_BATCH_SIZE")) - if os.environ.get("MAX_BATCH_SIZE") is not None - else None -) - - -def torch_compile_for_eager(func): - if LAZY_MODE == 1: - return func - return torch.compile( - func, backend="hpu_backend", options={"keep_input_mutations": True} - ) - - -def round_up_seq(number, k, base): - exponent = max(0, math.ceil(math.log(number / k, base))) - return int(k * (base**exponent)) - - -def iterate_powers_of_base(max_value, start, base): - current = start - result = [] - assert ( - max_value >= start - ), f"max_value {max_value} must be greater than start {start}" - while current < max_value: - result.append(current) - current *= base - return result - - -def round_up_batch(number): - return BATCH_SIZE_EXPONENT_BASE ** ( - math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE)) - ) - - -def to_tensor_indices(indices, device): - return torch.tensor(indices, dtype=torch.long, device=device) - - -def calculate_chunks(offset): - result = [] - while offset != 0: - sign = 1 if offset > 0 else -1 - best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1] - result.append(best_chunk) - offset = offset - best_chunk - return result - - -def biggest_single_chunk(offset): - if offset != 0: - idx = bisect.bisect(CHUNK_SIZES, abs(offset)) - return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) - else: - return 0 - - -@torch_compile_for_eager -def grouped_pad(tensor_groups, dims, values): - grouped_result = [] - for tensors, dim, value in zip(tensor_groups, dims, values): - padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0 - if padding > 0: - assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}" - pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) - result = [ - torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors - ] - else: - result = [t for t in tensors] - grouped_result.append(result) - htorch.core.mark_step() - return grouped_result - - -@torch_compile_for_eager -def roll(tensor, chunk, dim, merge_graphs): - if dim is None: - return tensor - tensor = torch.roll(tensor, chunk, dim) - if not merge_graphs: - htorch.core.mark_step() - return tensor - - -def grouped_roll(tensor_groups, chunk, dims, merge_graphs): - tensor_groups = [ - [roll(t, chunk, dim, merge_graphs) for t in tensors] - for tensors, dim in zip(tensor_groups, dims) - ] - if merge_graphs: - htorch.core.mark_step() - return tensor_groups - - -@torch_compile_for_eager -def grouped_shift(tensor_groups, dims, offset, merge_graphs): - chunks = calculate_chunks(offset) - for c in chunks: - tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs) - return tensor_groups - - -def move(dst_tensors, dst_indices, src_tensors): - bs_dim = 0 - num_indices = dst_indices.size(0) - for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)): - if src_t.size(bs_dim) != num_indices: - src_t = torch.narrow(src_t, bs_dim, 0, num_indices) - dst_t.index_copy_(bs_dim, dst_indices, src_t) - htorch.core.mark_step() - - -def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups): - for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups): - move(dst_tensors, dst_indices, src_tensors) - - -@torch_compile_for_eager -def extend_tensor(tensor, padding, dim): - result = torch.cat([tensor, padding], dim=dim) - htorch.core.mark_step() - return result - - -@torch_compile_for_eager -def extend_batch(tensors, target_bs, dim): - diff = target_bs - tensors[0].size(dim) - # TODO: add support for shrinking bs - if diff <= 0: - return tensors - shape = list(tensors[0].shape) - shape[dim] = diff - padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) - tensors = [extend_tensor(t, padding, dim) for t in tensors] - return tensors - - -def grouped_extend_batch(tensor_groups, target_bs, bs_dims): - tensor_groups = [ - extend_batch(tensors, target_bs, dim) - for tensors, dim in zip(tensor_groups, bs_dims) - ] - return tensor_groups - - -@torch_compile_for_eager -def merge(tensor_group): - tensor_group = [torch.stack(tensor_group)] - htorch.core.mark_step() - return tensor_group - - -@torch_compile_for_eager -def split(tensor_group, clone_data): - tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)] - if clone_data: - tensor_group = [t.clone() for t in tensor_group] - htorch.core.mark_step() - return tensor_group - - -def remove_kv_cache_from_output(module): - orig_fwd = module.forward - - @wraps(orig_fwd) - def forward(*args, **kwargs): - if kwargs["past_key_values"] is not None: - kwargs["return_dict"] = False - output = orig_fwd(*args, **kwargs) - first_value, second_value, *_ = output - if first_value.nelement() < 2: - return second_value - else: - return first_value - else: - kwargs["return_dict"] = True - return orig_fwd(*args, **kwargs) - - module.forward = forward - return module - - -@dataclass -class CausalLMRequest: - idx: int - data: generate_pb2.Request - input_length: int - prefix_offset: int - read_offset: int - stopping_criteria: StoppingCriteria - - all_input_ids: torch.Tensor - - @classmethod - def from_pb( - cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase - ): - return cls( - idx=idx, - data=data, - input_length=None, - prefix_offset=None, - read_offset=None, - stopping_criteria=StoppingCriteria.from_pb( - data.stopping_parameters, tokenizer - ), - all_input_ids=None, - ) - - def update_idx(self, new_idx): - prev = self.idx - self.idx = new_idx - return (new_idx, prev) - - -@dataclass -class CausalLMBatch(Batch): - batch_id: int - requests: List[CausalLMRequest] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - past_key_values: Optional[List[Tuple]] - merged_kv_cache: bool - - # Lengths of all generations present in the batch - input_length: int - - # Generation helpers - next_token_chooser: HeterogeneousNextTokenChooser - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - input_length: int - - # Past metadata - logits = None - past = None - - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.data.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - def detach_kv_cache(self): - past_keys = [past[0] for past in self.past_key_values] - past_values = [past[1] for past in self.past_key_values] - del self.past_key_values - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - # TODO: Add support for models that don't store kv_cache in a list - self.past_key_values = list(zip(past_keys, past_values)) - - def merge_kv_cache_if_needed(self, target_bs, offset): - pad_needed = self.seq_length < MAX_TOTAL_TOKENS - shift_needed = offset != 0 - expand_needed = target_bs > self.batch_size - # Very simple heuristic to determine whether we should merge tensors - # this needs tuning for other models/scenarios - small_bs = len(self.past_key_values) > self.batch_size - if ( - not self.merged_kv_cache - and small_bs - and (pad_needed or shift_needed or expand_needed) - ): - past_keys, past_values = self.detach_kv_cache() - past_keys = merge(past_keys) - past_values = merge(past_values) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = True - - def split_kv_cache_if_needed(self, clone_data): - if self.merged_kv_cache: - past_keys, past_values = self.detach_kv_cache() - past_keys = split(past_keys, clone_data) - past_values = split(past_values, clone_data) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = False - - def get_tensor_groups(self): - past_keys, past_values = self.detach_kv_cache() - seq_dim = -1 - key_dim = -2 if self.keys_head_dim_last else -1 - value_dim = -2 - tensors = [ - [self.input_ids], - [self.attention_mask], - [self.position_ids], - past_keys, - past_values, - ] - # We don't need to align position_ids - seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] - bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) - return tensors, seq_dims, bs_dims - - def set_tensor_groups(self, tensors): - self.input_ids = tensors.pop(0)[0] - self.attention_mask = tensors.pop(0)[0] - self.position_ids = tensors.pop(0)[0] - past_keys = tensors.pop(0) - past_values = tensors.pop(0) - self.attach_kv_cache(past_keys, past_values) - - def realign(self, target_bs, offset, pad_token_id): - tensors, seq_dims, _ = self.get_tensor_groups() - tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0]) - tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache) - self.set_tensor_groups(tensors) - - def expand_bs(self, target_bs): - tensors, _, bs_dims = self.get_tensor_groups() - tensors = grouped_extend_batch(tensors, target_bs, bs_dims) - self.set_tensor_groups(tensors) - - def used_indices(self): - return [req.idx for req in self.requests] - - def update_indices(self, new_indices): - for req, new_idx in zip(self.requests, new_indices): - req.idx = new_idx - return self.used_indices() - - def free_indices_generator(self): - used = set(req.idx for req in self.requests) - return (i for i in range(self.batch_size) if i not in used) - - def move_data(self, src_batches): - dst_tensors, _, dst_dims = self.get_tensor_groups() - free_indices_gen = self.free_indices_generator() - for src_b in src_batches: - dst_indices = to_tensor_indices( - src_b.update_indices(free_indices_gen), self.input_ids.device - ) - src_tensors, _, src_dims = src_b.get_tensor_groups() - grouped_move(dst_tensors, dst_indices, src_tensors) - self.set_tensor_groups(dst_tensors) - - @classmethod - def recombine( - cls, batches: List["CausalLMBatch"], pad_token_id: int - ) -> "CausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! Cannot recombine before prefill!") - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - new_bs = round_up_batch(total_requests) - - batch_id = batches[0].batch_id - device = batches[0].input_ids.device - - input_lengths = [b.input_length for b in batches] - max_input_length = max(input_lengths) - offsets = [max_input_length - b.input_length for b in batches] - - cur_padding = [b.right_padding for b in batches] - # For prefill there is a space allocated only for first token - # Need to add padding to the max total tokens before first decode - - moves_needed = [ - total_requests - len(b) if b.batch_size == new_bs else total_requests - for b in batches - ] - dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = batches[dst_batch_idx].batch_size < new_bs - - # TODO: Add support for changing max seq len, i.e. due to output length bucketing - # FIXME: max_seq_len for non optimized code - if len(batches) > 1: - scenario = "CONCAT" - elif reshape: - scenario = "RESHAPE" - elif cur_padding[dst_batch_idx] <= 0: - scenario = "SHIFT" - offsets = [ - biggest_single_chunk(b.max_input_length - max_input_length) - for b in batches - ] - max_input_length = max_input_length + offsets[dst_batch_idx] - else: - # Nothing to do - return batches[0] - - dbg_trace( - scenario, - f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}" - f" offsets:{offsets}" - f" input_lengths:{input_lengths}" - f" cur_padding:{cur_padding}" - f" dst_batch:{dst_batch_idx}", - ) - - grouped_requests = [[req for req in batch.requests] for batch in batches] - flat_requests = list(itertools.chain(*grouped_requests)) - - for i in range(len(batches)): - target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size - batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) - batches[i].realign(target_bs, offsets[i], pad_token_id) - batches[i].split_kv_cache_if_needed(i == dst_batch_idx) - batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data( - [batches[i] for i in range(len(batches)) if i != dst_batch_idx] - ) - - top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens.extend([-1] * (new_bs - total_requests)) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - parameters = [r.data.parameters for r in flat_requests] - # append the dummy parameters for dummy requests - batch_size = batches[dst_batch_idx].batch_size - parameters = pad_next_token_chooser_parameters(parameters, batch_size) - - # update past grammar states - fsm_grammar_states = [0] * batch_size - for batch in batches: - for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = ( - batch.next_token_chooser.fsm_grammar_states[i] - ) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[dst_batch_idx].next_token_chooser.dtype, - batches[dst_batch_idx].next_token_chooser.device, - batches[dst_batch_idx].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - input_ids = batches[dst_batch_idx].input_ids - attention_mask = batches[dst_batch_idx].attention_mask - position_ids = batches[dst_batch_idx].position_ids - past_key_values = batches[dst_batch_idx].past_key_values - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=flat_requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") - requests = [ - CausalLMRequest.from_pb(idx, req, tokenizer) - for idx, req in enumerate(pb.requests) - ] - inputs = [] - top_n_tokens = [] - - # Parse batch - max_truncation = 0 - for i, r in enumerate(pb.requests): - inputs.append(concat_text_chunks(r.input_chunks.chunks)) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - - max_input_length = max_truncation - if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF: - max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up_batch(len(requests)) - missing_inputs = new_bs - len(inputs) - dummy_inputs = ["?"] * missing_inputs - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - tokenized_inputs = tokenizer( - inputs + dummy_inputs, - return_tensors="pt", - padding="longest", - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ) - - input_len = tokenized_inputs["input_ids"].shape[1] - # Round up sequence length - bucket_size = max_input_length - left_padding = max_input_length - input_len - if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: - assert ( - PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length - ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" - rounded_seq_len = round_up_seq( - input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE - ) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - - old_bs = len(requests) - top_n_tokens.extend([-1] * (new_bs - old_bs)) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - htorch.core.mark_step() - return cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: - dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}") - request_ids = set(request_ids) - self.requests = [req for req in self.requests if req.data.id in request_ids] - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["CausalLMBatch"], pad_token_id: int = 0 - ) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id) - - def __len__(self): - return len(self.requests) - - @property - def max_input_length(self): - return max(req.input_length for req in self.requests) - - @property - def batch_size(self): - return self.attention_mask.size(0) - - @property - def seq_length(self): - return self.attention_mask.size(1) - - @property - def right_padding(self): - return self.seq_length - self.input_length - - # Maximum number of tokens this batch will grow to - @property - def max_tokens(self): - max_total_tokens = self.attention_mask.size(1) - return len(self.requests) * max_total_tokens - - -class CausalLM(Model): - def __init__( - self, - model_id: str, - model_class: Optional[Type[torch.nn.Module]] = None, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - default_dtype=torch.float16, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - config_class=AutoConfig, - batch_class=CausalLMBatch, - ): - if speculator: - raise RuntimeError("Speculator decoding is not enabled for AutoModel") - - self.prev_bs = 0 - self.quantize = quantize - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - # Get weight files - weight_files(model_id, revision=revision, extension=".safetensors") - - if world_size > 1: - os.environ.setdefault( - "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" - ) - model = self.get_deepspeed_model(model_id, dtype, revision) - model = hq_env.prepare_model_for_quantization(model) - else: - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained(model_id) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs, - ) - model = hq_env.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = ( - os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - ) - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true" - - if model.config.model_type not in [ - "gpt_bigcode" - ]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output() - model = remove_kv_cache_from_output(model) - - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace("TORCH COMPILE", "Torch compiling of model") - model.model = torch.compile( - model.model, - backend="hpu_backend", - options={"keep_input_mutations": True}, - ) - - model = hq_env.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - if isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" - ) - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - self.kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in [ - "llama", - "mistral", - "starcoder2", - "qwen2", - "falcon", - "gpt_bigcode", - ]: - if model.config.model_type not in ["falcon", "gpt_bigcode"]: - self.kwargs["attn_softmax_bf16"] = True - - if model.config.model_type not in ["gpt_bigcode"]: - self.kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true": - self.kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true": - self.kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = ( - int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_steps = ( - int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes, - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - def get_deepspeed_model( - self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = {"revision": revision} - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format( - world_size, rank, local_rank - ) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) - else: - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = AutoModelForCausalLM.from_pretrained( - model_id, torch_dtype=dtype, **model_kwargs - ) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [ - str(f) - for f in weight_files( - model_id, revision=revision, extension=".safetensors" - ) - ] - data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, checkpoints_json) - checkpoints_json.flush() - - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return {"type": rope_scaling, "factor": float(rope_factor)} - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode( - all_input_ids[read_offset:], skip_special_tokens=False - ) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) - - def forward( - self, - input_ids, - attention_mask, - position_ids, - token_idx, - past_key_values: Optional[List[Tuple]] = None, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "token_idx": token_idx, - } - - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama": - kwargs["lazy_mode"] = LAZY_MODE == 1 - - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - if bypass_hpu_graph is not None: - kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - - kwargs.update(self.kwargs) - - if past_key_values is not None and self.model.config.model_type not in [ - "gpt_bigcode" - ]: - return self.model.forward(**kwargs) - else: - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: List[CausalLMBatch] - ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = ( - batch.attention_mask.shape[-1] - batch.right_padding - ) - token_idx = torch.tensor(token_idx_scalar).to(self.device) - - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, - logits[:, input_length - 1 : input_length, :].squeeze(-2), - self.speculate, - ) - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append( - { - "next_token_ids": next_token_ids, - "next_token_logprobs": next_token_logprobs, - } - ) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append( - { - "req": req, - "prev_req_idx": req.idx, - "batch_id": batch_id, - "seed": batch.next_token_chooser.seeds[req_idx], - "do_sample": batch.next_token_chooser.do_sample[req_idx], - "top_n_tokens": batch.top_n_tokens[req_idx], - "top_token_ids": batch_top_token_ids[req_idx], - "top_token_logprobs": batch_top_token_logprobs[req_idx], - "grammar_state": batch.next_token_chooser.fsm_grammar_states[ - req.idx - ], - } - ) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # Adjust lengths - batch.input_length += 1 - - # Update position_ids - if prefill: - batch.position_ids = ( - torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - ) - else: - batch.position_ids += 1 - # Update past key values - if prefill or self.model.config.model_type in ["gpt_bigcode"]: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) - - scenario = "PREFILL" if prefill else "GENERATE" - if ( - self.enable_hpu_graph - and self.limit_hpu_graph - and round_up_batch(batch.batch_size) != self.prev_bs - ): - self.model.clear_cache() - self.prev_bs = round_up_batch(batch.batch_size) - dbg_trace( - scenario, - f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", - ) - assert batch.right_padding > 0, "No more room for next token!" - - # Execute batch - if prefill: - # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - token_idx = torch.tensor( - batch.attention_mask.shape[-1] - batch.right_padding - ).to(self.device) - input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) - logits = self.forward( - input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - if self.model.config.model_type in ["gpt_bigcode"]: - batch.logits, batch.past = logits - else: - batch.logits = logits - - htorch.core.mark_step() - - start_decode = time.time_ns() - - # Stage 3. Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch["next_token_logprobs"] = prev_batch[ - "next_token_logprobs" - ].tolist() - prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data["req"] - i = req_data["prev_req_idx"] - prev_batch_id = req_data["batch_id"] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] - next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data["do_sample"] - seed = req_data["seed"] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data["top_n_tokens"] - top_token_ids = req_data["top_token_ids"] - top_token_logprobs = req_data["top_token_logprobs"] - grammar_state = req_data["grammar_state"] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if ( - is_tokenizer_transparent(self.tokenizer) - and len(stopping_criteria.stop_sequence_criterias) == 0 - ): - next_token_text = "" - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[ - new_input_length - - stopping_criteria.current_tokens : new_input_length, - 0, - ] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - [next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if ( - self.step - > self.profiling_wait_steps - + self.profiling_warmup_steps - + self.profiling_steps - ): - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def generate_warmup_batch(self, request, seq_len, batch_size): - batch = copy.deepcopy(request.batch) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) - - def warmup( - self, request: generate_pb2.WarmupRequest - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - assert ( - MAX_BATCH_SIZE is not None - ), "MAX_BATCH_SIZE is not set, it should be set in the launcher" - MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens - logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}") - logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}") - MAX_TOTAL_TOKENS = request.max_total_tokens - - batch = self.batch_type.from_pb( - request.batch, self.tokenizer, self.dtype, self.device - ) - max_prefill_batch_size = batch.input_ids.shape[0] - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch]) - except Exception: - raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - del prefill_batch - - # Warmup prefill batch_size - max_input_tokens = request.max_input_tokens - max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE)) - prefill_batch_size_list = [ - BATCH_SIZE_EXPONENT_BASE**exp - for exp in range( - 0, - max_exp + 1, - ) - ] - prefill_seqlen_list = iterate_powers_of_base( - max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE - ) - prefill_seqlen_list.append(max_input_tokens) - prefill_batch_size_list.sort(reverse=True) - prefill_seqlen_list.sort(reverse=True) - try: - for batch_size in prefill_batch_size_list: - for seq_len in prefill_seqlen_list: - batch = self.generate_warmup_batch(request, seq_len - 1, batch_size) - _, prefill_batch, _ = self.generate_token([batch]) - except Exception: - prefill_batch_size_list.sort() - prefill_seqlen_list.sort() - raise RuntimeError( - f"Not enough memory to run following prefill batch_size." - f"Prefill batch size list:{prefill_batch_size_list}" - f"Prefill sequence length list:{prefill_seqlen_list}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - prefill_seqlen_list.sort() - prefill_batch_size_list.sort() - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing prefill warmup successfully.\n" - f"Prefill batch size list:{prefill_batch_size_list}\n" - f"Prefill sequence length list:{prefill_seqlen_list}\n" - f"Memory stats: {mem_stats} " - ) - - max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) - max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE)) - decode_batch_size_list = [ - BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1) - ] - decode_batch_size_list.sort(reverse=True) - - try: - for batch_size in decode_batch_size_list: - batches = [] - iters = math.floor(batch_size / max_prefill_batch_size) - for i in range(iters): - batch = self.generate_warmup_batch( - request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size - ) - _, prefill_batch, _ = self.generate_token([batch]) - batches.append(prefill_batch) - - if batch_size % max_prefill_batch_size != 0: - batch = self.generate_warmup_batch( - request, - PAD_SEQUENCE_TO_MULTIPLE_OF - 1, - batch_size % max_prefill_batch_size, - ) - _, prefill_batch, _ = self.generate_token([batch]) - batches.append(prefill_batch) - - _, decode_batch, _ = self.generate_token(batches) - _, decode_batch, _ = self.generate_token([decode_batch]) - del decode_batch - batches.clear() - - except Exception: - raise RuntimeError( - f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})." - f"You need to decrease `--max-batch-total-tokens`" - ) - - decode_batch_size_list.sort() - max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing decode warmup successfully.\n" - f"Decode batch size list:{decode_batch_size_list}\n" - f"Memory stats: {mem_stats} " - ) - - max_input_tokens = max_input_tokens - max_total_tokens = MAX_TOTAL_TOKENS - - return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py deleted file mode 100644 index 00ecdf95..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ /dev/null @@ -1,467 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch Llava-NeXT model.""" - -from typing import List, Optional, Union - -import torch -import torch.utils.checkpoint -import numpy as np - -from loguru import logger -from transformers.models.llava_next.modeling_llava_next import ( - unpad_image, -) -from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration -from transformers.image_processing_utils import select_best_resolution - - -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if not isinstance(grid_pinpoints, list): - raise ValueError("grid_pinpoints should be a list of tuples or lists") - - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - -# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79 -def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): - """ - Calculate the number of patches after the preprocessing for images of any resolution. - - Args: - image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): - The size of the input image in the format (height, width). ? - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - int: the number of patches - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type {type(image_size)} with value {image_size}" - ) - image_size = image_size.tolist() - - best_resolution = select_best_resolution(image_size, grid_pinpoints) - height, width = best_resolution - num_patches = 0 - # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - num_patches += 1 - # add the base patch - num_patches += 1 - return num_patches - - -class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): - - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - image_sizes: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, - ): - - if token_idx is not None: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - logits = outputs[0] - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411 - def pack_image_features( - self, - image_features, - image_sizes, - vision_feature_select_strategy, - image_newline=None, - ): - """ - Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. - - Args: - image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) - List of image feature tensor, each contains all the visual feature of all patches. - image_sizes (`torch.Tensor` of shape `(num_images, 2)`) - Actual image size of each images (H, W). - vision_feature_select_strategy (`str`) - The feature selection strategy used to select the vision feature from the vision backbone. - image_newline (`torch.Tensor` of shape `(embed_dim)`) - New line embedding vector. - Returns: - image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) - feature_lens (`List[int]`) - token length of each image in image_features - """ - new_image_features = [] - feature_lens = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - - if ( - np.prod(image_feature.shape) - % (num_patch_height * num_patch_width * height * width) - != 0 - and vision_feature_select_strategy == "default" - ): - logger.warning_once( - "Image feature shape does not line up with the provided patch size. " - "You may be using the `default` vision_feature_select_strategy with a" - " visual encoder that does not have CLS." - ) - - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - if image_newline is not None: - image_feature = torch.cat( - ( - image_feature, - image_newline[:, None, None] - .expand(*image_feature.shape[:-1], 1) - .to(image_feature.device, image_feature.dtype), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if image_newline is not None: - image_feature = torch.cat( - (image_feature, image_newline[None].to(image_feature)), dim=0 - ) - new_image_features.append(image_feature) - feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor( - feature_lens, dtype=torch.long, device=image_features.device - ) - return image_features, feature_lens - - # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479 - def get_image_features( - self, - pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, - ): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) - The tensors corresponding to the input images. - image_sizes (`torch.Tensor` of shape `(num_images, 2)`) - Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - vision_feature_select_strategy (`str`): - The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"` - Returns: - image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches - and are of shape `(num_patches, image_length, embed_dim)`). - """ - # ! infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - if pixel_values.dim() == 5: - # stacked if input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] - for pix_val, num_patch in zip(pixel_values, image_num_patches) - ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of (num_patches, num_channels, height, width) - raise ValueError( - f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions" - ) - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - # If we have one vision feature layer, return the corresponding hidden states, - # otherwise, select the hidden states of each feature layer and concatenate them - if isinstance(vision_feature_layer, int): - selected_image_feature = image_features.hidden_states[vision_feature_layer] - else: - hs_pool = [ - image_features.hidden_states[layer_idx] - for layer_idx in vision_feature_layer - ] - selected_image_feature = torch.cat(hs_pool, dim=-1) - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - image_features = torch.split(image_features, image_num_patches, dim=0) - return image_features - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - - position_ids = kwargs.get("position_ids", None) - labels = kwargs.get("labels", None) - if ( - past_key_values is None - and pixel_values is not None - and input_ids.shape[1] != 1 - ): - vision_feature_select_strategy = kwargs.get( - "vision_feature_select_strategy", None - ) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. Merge text and images - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=vision_feature_select_strategy, - image_newline=self.image_newline, - ) - - special_image_mask = ( - input_ids == self.config.image_token_index - ).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - if inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_features = image_features.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds.masked_scatter( - special_image_mask, image_features - ) - - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = ( - torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py deleted file mode 100644 index 6ba0ffff..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py +++ /dev/null @@ -1,292 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Mllama model.""" - -from typing import Optional, Tuple, List, Union - -import torch -import torch.utils.checkpoint - -from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration -from optimum.habana.transformers.models.mllama.modeling_mllama import ( - _prepare_cross_attention_mask, -) -from transformers.modeling_outputs import CausalLMOutputWithPast - - -class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration): - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, - **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - """ - Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 - The only differences are: - - add token_idx input - - add use_flash_attention and flash_attention_recompute - """ - full_text_row_masked_out_mask = kwargs.get( - "full_text_row_masked_out_mask", None - ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - labels=labels, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - logits = outputs[0] - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - num_logits_to_keep=None, - **kwargs, - ): - """ - Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 - The only differences are: - - add token_idx handling - - add bucket_internal handling - - add use_flash_attention and flash_attention_recompute - """ - - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - cross_attention_mask=cross_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - position_ids = kwargs.get("position_ids", None) - output_attentions = kwargs.get("output_attentions", None) - output_hidden_states = kwargs.get("output_hidden_states", None) - return_dict = kwargs.get("return_dict", None) - labels = kwargs.get("labels", None) - cross_attention_states = kwargs.get("cross_attention_states", None) - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - bucket_internal = kwargs.get("bucket_internal", None) - - if past_key_values is not None: - if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - elif inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - elif bucket_internal and token_idx is not None: - # for the 1st token we can slice the inputs till token idx for the fwd pass. - input_ids = input_ids[:, :token_idx] - attention_mask = attention_mask[:, :token_idx] - if cross_attention_mask is not None: - cross_attention_mask = cross_attention_mask[:, :token_idx, ...] - - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.index_select( - position_ids, 1, token_idx - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone( - memory_format=torch.contiguous_format - ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - use_flash_attention=use_flash_attention, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - token_idx=token_idx, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None: - if cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - elif past_key_values is not None: - if token_idx is not None: - cross_attention_mask = torch.index_select( - cross_attention_mask, -2, token_idx - 1 - ) - full_text_row_masked_out_mask = torch.index_select( - full_text_row_masked_out_mask, -2, token_idx - 1 - ) - else: - cross_attention_mask = cross_attention_mask[:, :, -1:] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, -1: - ] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = { - "input_ids": input_ids.clone(memory_format=torch.contiguous_format), - "inputs_embeds": None, - } - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - # keep cache_position implementation as None for HPU - cache_position = None - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "return_dict": kwargs.get("return_dict"), - "full_text_row_masked_out_mask": full_text_row_masked_out_mask, - "use_flash_attention": use_flash_attention, - "cross_attention_mask": cross_attention_mask, - "cross_attention_states": cross_attention_states, - "output_attentions": output_attentions, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index a80a86a7..ac1578e9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -54,7 +54,8 @@ import habana_frameworks.torch as htorch # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py from typing import Union from transformers.feature_extraction_utils import BatchFeature -from transformers.image_utils import ImageInput, VideoInput +from transformers.image_utils import ImageInput +from transformers.video_utils import VideoInput from transformers.processing_utils import ( ProcessingKwargs, ProcessorMixin, diff --git a/backends/gaudi/server/text_generation_server/models/galactica.py b/backends/gaudi/server/text_generation_server/models/galactica.py deleted file mode 100644 index 7c4e462c..00000000 --- a/backends/gaudi/server/text_generation_server/models/galactica.py +++ /dev/null @@ -1,156 +0,0 @@ -import re -import torch -import torch.distributed - - -from transformers import ( - PreTrainedTokenizerBase, -) -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, -) -from text_generation_server.utils.chunks import concat_text_chunks - -# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py - -# we split individual characters inside special tokens like [START_DNA] -CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])") - -# token added to implement a custom sequence tokenization. This token is added at -# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance -# that they do not occur in the corpus. The digits are escaped so that the token does not appear -# literally in the source code in case we ever include it in the training data. -SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E" - - -def _insert_split_marker(m: re.Match): - """ - Applies split marker based on a regex match of special tokens such as - [START_DNA]. - Parameters - ---------- - n : str - Input text to split - Returns - ---------- - str - the text with the split token added - """ - start_token, _, sequence, end_token = m.groups() - sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL) - return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}" - - -def escape_custom_split_sequence(text): - """ - Applies custom splitting to the text for GALILEO's tokenization - Parameters - ---------- - text : str - Input text to split - Returns - ---------- - str - the text with the split token added - """ - return CUSTOM_SEQ_RE.sub(_insert_split_marker, text) - - -# END CREDIT - - -class GalacticaCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "GalacticaCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - top_n_tokens = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - # Add escape_custom_split_sequence to the CausalLMBatch logic - inputs.append( - escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks)) - ) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - tokenized_inputs = tokenizer( - inputs, - return_tensors="pt", - padding=True, - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(0) - read_offsets.append(input_len) - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - max_tokens = len(inputs) * max_input_length + max_decode_tokens - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index cd221e14..cdde67ca 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -4,14 +4,14 @@ from loguru import logger from text_generation_server.utils.log import log_master REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} -ATTENTION = os.getenv("ATTENTION", "default") +ATTENTION = os.getenv("ATTENTION", "paged") # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { "1", "true", } log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "default"} +_expected = {"paged"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py deleted file mode 100644 index 98d7352a..00000000 --- a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py +++ /dev/null @@ -1,882 +0,0 @@ -from io import BytesIO -from PIL import Image -import torch -import time - -from dataclasses import dataclass -from opentelemetry import trace -from transformers import ( - AutoConfig, - AutoProcessor, - AutoTokenizer, - PreTrainedTokenizerBase, - ProcessorMixin, -) -from typing import Optional, Tuple, List, Type, Dict - -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -import torch.distributed -from text_generation_server.models.custom_modeling.idefics_modeling import ( - IdeficsForVisionText2Text, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.quantization import get_loader - -tracer = trace.get_tracer(__name__) - - -@dataclass -class IdeficsCausalLMBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - pixel_values: Optional[torch.Tensor] - image_hidden_states: Optional[torch.Tensor] - image_attention_mask: Optional[torch.Tensor] - past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "IdeficsCausalLMBatch": - raise NotImplementedError - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor: ProcessorMixin, # Hack - config, - dtype: torch.dtype, - device: torch.device, - ) -> "IdeficsCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.input_chunks.chunks) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - # TODO Check impact on idefics - prompts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompt = [] - for chunk in inp: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - prompt.append(chunk.text) - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - prompt.append(image) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - prompts.append(prompt) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - tokenized_inputs = processor( - prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - # TODO Check impact on idefics - # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append( - input_len - 5 - ) # To decode without potential fallbacks errors - read_offsets.append( - input_len - ) # To decode without potential fallbacks errors - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - pixel_values = tokenized_inputs.get("pixel_values", None) - image_hidden_states = None - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - # Do the same for image_attention_mask - if pixel_values is None: - image_attention_mask = None - else: - image_attention_mask = input_ids.new_zeros( - ( - pb.size, - max_input_length + padding_right_offset, - pixel_values.size(1), - ) - ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ - "image_attention_mask" - ] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split( - 1, dim=1 - ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list - - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: - # It deletes requests from the batch. For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - # Do the same for pixel_values and image_attention_mask - pixel_values = self.pixel_values[keep_indices] - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] - if self.image_hidden_states is None: - image_hidden_states = None - else: - image_hidden_states = self.image_hidden_states[keep_indices] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) is tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.pixel_values = pixel_values - self.image_hidden_states = image_hidden_states - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["IdeficsCausalLMBatch"] - ) -> "IdeficsCausalLMBatch": - # It adds new requests to the batch - # Used for padding - total_batch_size = 0 - max_input_length = 0 - max_num_images = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - max_num_images = max(max_num_images, batch.pixel_values.size(1)) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - pixel_values = None - image_hidden_states = None - image_attention_mask = None - past_key_values = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - curr_batch_max_num_images = batch.pixel_values.size(1) - if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) - ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values - ) - - if image_attention_mask is None: - image_attention_mask = batch.image_attention_mask.new_zeros( - ( - total_batch_size, - max_input_length + padding_right_offset, - max_num_images, - ) - ) - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values[0], tuple): - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] - ) - else: - # BLOOM case - padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] - ) - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values - - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) - - def __len__(self): - return len(self.requests) - - -class IdeficsCausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - self.process_group, rank, world_size = initialize_torch_distributed() - device = torch.device("hpu") - dtype = torch.bfloat16 if dtype is None else dtype - self.device, self.dtype = device, dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - config.vision_config.quantize = quantize - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) - - model = IdeficsForVisionText2Text(config, weights) - - self.config = config - - torch.distributed.barrier(group=self.process_group) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[IdeficsCausalLMBatch]: - return IdeficsCausalLMBatch - - def forward( - self, - input_ids, - attention_mask, - position_ids, - pixel_values, - image_hidden_states, - image_attention_mask, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "image_hidden_states": image_hidden_states, - "image_attention_mask": image_attention_mask, - "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, - } - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - outputs, speculative_logits = self.model.forward(**kwargs) - return ( - outputs.logits, - speculative_logits, - outputs.past_key_values, - outputs.image_hidden_states, - ) - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: IdeficsCausalLMBatch - ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - if batch.image_attention_mask is None: - image_attention_mask = None - else: - if batch.input_ids.size(1) == 1: - # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), - # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension - # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated - # token need to attend to the encoder hidden states (i.e. the vision encoder) - # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[ - :, -(batch.padding_right_offset + 1) - ].unsqueeze(1) - else: - image_attention_mask = batch.image_attention_mask[ - :, : -batch.padding_right_offset - ] - - logits, speculative_logits, past, image_hidden_states = self.forward( - input_ids=batch.input_ids, - attention_mask=attention_mask, - position_ids=batch.position_ids, - pixel_values=batch.pixel_values, - image_hidden_states=batch.image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=batch.past_key_values, - ) - # Hardcoded remove image tokens - logits[:, 32000:32001] = torch.finfo(logits.dtype).min - - start_decode = time.time_ns() - - # Results - generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( - next_token_id_squeezed.item() - ) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - batch.position_ids = batch.position_ids[:, -1:] + 1 - - # Update past key values - batch.past_key_values = past - batch.image_hidden_states = image_hidden_states - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/mamba.py b/backends/gaudi/server/text_generation_server/models/mamba.py deleted file mode 100644 index f6dcde68..00000000 --- a/backends/gaudi/server/text_generation_server/models/mamba.py +++ /dev/null @@ -1,814 +0,0 @@ -import torch -import torch.distributed -from transformers import AutoTokenizer, PreTrainedTokenizerBase -from typing import Optional -from text_generation_server.models.custom_modeling.mamba_modeling import ( - MambaConfig, -) -from loguru import logger -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL -import time -from text_generation_server.models.custom_modeling.mamba_modeling import ( - MambaModel, - InferenceParams, -) -from text_generation_server.models import Model -from typing import Any, List, Tuple, Type, Dict -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.quantization import get_loader -from text_generation_server.utils.tokens import batch_top_tokens, Sampling -from dataclasses import dataclass -from text_generation_server.utils import NextTokenChooser, StoppingCriteria - - -def new_inference_params( - n_blocks: int, - batch_size: int, - d_inner: int, - d_conv: int, - d_state: int, - seqlen_offset: int, - dtype: torch.dtype, - device: torch.device, -): - max_seqlen = 0 - conv_states = torch.zeros( - ( - n_blocks, - batch_size, - d_inner, - d_conv, - ), - device=device, - dtype=dtype, - ) - ssm_states = torch.zeros( - ( - n_blocks, - batch_size, - d_inner, - d_state, - ), - device=device, - dtype=dtype, - ) - inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_offset, - conv_states=conv_states, - ssm_states=ssm_states, - ) - return inference_params - - -@dataclass -class MambaBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - # Inference params - inference_params: Optional[Dict[str, Any]] = None - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "MambaBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(concat_text_chunks(r.input_chunks.chunks)) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - tokenized_inputs = tokenizer( - inputs, - return_tensors="pt", - padding=True, - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - input_ids = tokenized_inputs["input_ids"] - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - # past_input_ids=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - indices = [] - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - indices.append(idx) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - - top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.top_n_tokens = top_n_tokens - self.top_n_tokens_tensor = top_n_tokens_tensor - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - # TODO - # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. - self.inference_params.conv_states = self.inference_params.conv_states[ - :, indices - ] - self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] - return self - - @classmethod - def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - max_tokens = 0 - seqlen_offset = 0 - - (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape - (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape - dtype = batches[0].inference_params.conv_states.dtype - device = batches[0].inference_params.conv_states.device - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=total_batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=device, - dtype=dtype, - ) - - # Batch tensors - input_ids = None - top_n_tokens_tensor = None - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - top_n_tokens.extend(batch.top_n_tokens) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - total_batch_size, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - inference_params.max_seqlen = max( - inference_params.max_seqlen, batch.inference_params.max_seqlen - ) - assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset" - inference_params.seqlen_offset = max( - inference_params.seqlen_offset, batch.inference_params.seqlen_offset - ) - - inference_params.conv_states[:, start_index:end_index] = ( - batch.inference_params.conv_states - ) - inference_params.ssm_states[:, start_index:end_index] = ( - batch.inference_params.ssm_states - ) - - start_index = end_index - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - inference_params=inference_params, - ) - - def __len__(self): - return len(self.requests) - - -class Mamba(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - self.process_group, _rank, world_size = initialize_torch_distributed() - if world_size > 1: - raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") - self.cuda_graphs = {} - if torch.cuda.is_available(): - device = torch.device("cuda") - # Bf16 is important. In f16 accumulations in the matmul are causing - # differences while the server is under load. - # This is detectable by the integration load test - dtype = torch.bfloat16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - "EleutherAI/gpt-neox-20b", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = MambaConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) - model = MambaModel(config, weights) - torch.distributed.barrier(group=self.process_group) - super(Mamba, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - @property - def batch_type(self) -> Type[MambaBatch]: - return MambaBatch - - def warmup(self, batch) -> Optional[int]: - # TODO: implement warmup for Mamba if needed - if CUDA_GRAPHS: - if self.speculate is None or self.speculate == 0: - try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") - # Warmup cuda graphs - for bs in CUDA_GRAPHS: - self.cuda_graph_warmup(bs) - except Exception: - logger.exception("Decode cuda graph warmup failed") - else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") - - return None - - def cuda_graph_warmup(self, batch_size: int): - input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) - n_blocks = len(self.model.blocks) - - d_state = self.model.config.d_state - d_conv = self.model.config.d_conv - # Inner takes the expand multiplication - d_inner = self.model.config.d_inner - - # Important seqlen_offset to go through the update mecanism with the state - seqlen_offset = 1 - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - - graph = torch.cuda.CUDAGraph() - - torch.cuda.synchronize() - # Run once outside to warmup - self.model.forward(input_ids=input_ids, inference_params=inference_params) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - logits, speculative_logits = self.model.forward( - input_ids=input_ids, inference_params=inference_params - ) - torch.cuda.synchronize() - graph_dict = { - "input_ids": input_ids, - "inference_params": inference_params, - "graph": graph, - "logits": logits, - "speculative_logits": speculative_logits, - } - self.cuda_graphs[batch_size] = graph_dict - - def tunableop_warmup(self, batch_size: int, seqlen: int): - input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) - n_blocks = len(self.model.blocks) - - d_state = self.model.config.d_state - d_conv = self.model.config.d_conv - # Inner takes the expand multiplication - d_inner = self.model.config.d_inner - - # Important seqlen_offset to go through the update mecanism with the state - seqlen_offset = 1 - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=seqlen, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - - self.model.forward(input_ids=input_ids, inference_params=inference_params) - - def forward( - self, input_ids: torch.Tensor, inference_params: Any - ) -> Tuple[torch.Tensor, torch.Tensor]: - bs = input_ids.shape[0] - padded_bs = bs - if bs == 3: - padded_bs = 4 - elif 3 < bs <= 8: - padded_bs = 8 - elif bs > 8: - padded_bs = (bs + 7) // 8 * 8 - - # Try to find an associated cuda graph - cuda_graph = self.cuda_graphs.get(padded_bs, None) - is_prefill = inference_params is None or inference_params.seqlen_offset == 0 - - if is_prefill or cuda_graph is None: - return self.model( - input_ids, - inference_params=inference_params, - ) - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][:bs] = input_ids - cuda_graph["inference_params"].conv_states[ - :, :bs - ] = inference_params.conv_states - cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states - - # Replay the graph - cuda_graph["graph"].replay() - - inference_params.conv_states.copy_( - cuda_graph["inference_params"].conv_states[:, :bs] - ) - inference_params.ssm_states.copy_( - cuda_graph["inference_params"].ssm_states[:, :bs] - ) - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None - ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits - - def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: - start = time.time_ns() - input_ids = ( - batch.input_ids - ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids - - batch_size, max_seqlen = input_ids.shape - # Inference params - - if batch.inference_params is None: - # 0 is important here - seqlen_offset = 0 - n_blocks = len(self.model.blocks) - d_state = self.model.config.d_state - d_conv = self.model.config.d_conv - d_inner = self.model.config.d_inner - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - batch.inference_params = inference_params - - # Forward pass - logits, speculative_logits = self.forward( - input_ids, inference_params=batch.inference_params - ) - - # batch.inference_params = new_inference_params - # Results - generations: List[Generation] = [] - stopped = True - - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - torch.log_softmax(logits[:, -1], -1), - accepted_ids, - ) - - start_decode = time.time_ns() - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - batch.top_n_tokens, - batch_top_token_ids, - batch_top_token_logprobs, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - top_n_tokens, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[ - i - ].advance_grammar(next_token_id_squeezed.item()) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/starcoder.py b/backends/gaudi/server/text_generation_server/models/starcoder.py deleted file mode 100644 index 6c6ca2cf..00000000 --- a/backends/gaudi/server/text_generation_server/models/starcoder.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from dataclasses import dataclass -from typing import List, Optional, Type - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch - - -@dataclass -class StarCoderCausalLMBatch(CausalLMBatch): - past_key_values: Optional[List[torch.Tensor]] - - def detach_kv_cache(self): - past_keys = [] - past_values = [] - last_dim = int(self.past_key_values[0].size(dim=-1) / 2) - for key_value in self.past_key_values: - past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) - past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) - del self.past_key_values - - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - self.past_key_values = [ - torch.cat((key, value), dim=-1) - for key, value in zip(past_keys, past_values) - ] - - -class StarCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - - super(StarCoder, self).__init__( - model_id=model_id, - revision=revision, - dtype=dtype, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return StarCoderCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py deleted file mode 100644 index 6929b2ef..00000000 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ /dev/null @@ -1,1609 +0,0 @@ -import json -import re -import torch -import os -import time -import math -from PIL import Image -from io import BytesIO -from opentelemetry import trace -from loguru import logger -from typing import Iterable, Optional, Tuple, List, Type, Dict -import tempfile -import copy -from text_generation_server.models import Model -from transformers import PreTrainedTokenizerBase -from text_generation_server.utils import weight_files -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import ( - CausalLMBatch, - CausalLMRequest, - remove_kv_cache_from_output, -) - -from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, -) - -from transformers import AutoProcessor -import text_generation_server.habana_quantization_env as hq_env -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - make_tokenizer_optional, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from optimum.habana.utils import get_hpu_memory_stats -from optimum.habana.checkpoint_utils import get_ds_injection_policy - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from optimum.habana.checkpoint_utils import model_on_meta - -from text_generation_server.utils.speculate import get_speculate -from text_generation_server.models.types import ( - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.debug import dbg_trace - -tracer = trace.get_tracer(__name__) - -IDEFICS2_FAKE_TOKEN = "" -IDEFICS2_IMAGE_TOKEN = "" - - -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") -BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048)) -MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) - - -PREFILL_WARMUP_BATCH_SIZE_LIST = [] -PREFILL_WARMUP_SEQLEN_LIST = [] -DECODE_WARMUP_BATCH_SIZE_LIST = [] -CROSS_ATTENTION_LAYERS = [] - - -def round_up(warmup_list: list, num): - i = 0 - for i in warmup_list: - if num <= i: - break - return i if i > 0 else num - - -def split(string) -> List[Dict[str, str]]: - parts = [] - cursor = 0 - for pattern in IMAGES.finditer(string): - start = pattern.start() - if start != cursor: - parts.append({"type": "text", "content": string[cursor:start]}) - - parts.append({"type": "image", "content": pattern.group(1)}) - cursor = pattern.end() - - if cursor != len(string): - parts.append({"type": "text", "content": string[cursor:]}) - - return parts - - -def image_text_replacement(config) -> str: - if config.model_type == "idefics2": - image_seq_len = 64 - image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" - return image_str - elif config.model_type == "llava_next": - return "" - elif config.model_type == "paligemma": - return "" - elif config.model_type == "mllama": - return "<|image|>" - else: - raise RuntimeError(f"Unknown config {config.model_type} for multimodal") - - -def image_text_replacement_fixup(config, text: str) -> str: - if config.model_type == "idefics2": - return text.replace( - f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN - ) - return text - - -def get_unpadded_features( - original_height: int, - original_width: int, - npatches: int, - num_patch_height: int, - num_patch_width: int, -) -> Tuple[int, int]: - current_height = npatches * num_patch_height - current_width = npatches * num_patch_width - - aspect_ratio: float = original_width / original_height - current_aspect_ratio: float = current_width / current_height - - if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width - padding = (current_height - new_height) // 2 - current_height = current_height - (2 * padding) - else: - new_width = (original_width * current_height) // original_height - padding = (current_width - new_width) // 2 - current_width = current_width - (2 * padding) - - unpadded_features = current_height * current_width - newline_features = current_height - return (unpadded_features, newline_features) - - -def get_number_of_features(height: int, width: int, config) -> int: - # From config - # Hardcoded for CLIP for now - # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] - image_grid_pinpoints = config.image_grid_pinpoints - image_size = config.vision_config.image_size - patch_size = config.vision_config.patch_size - - assert image_size % patch_size == 0 - - npatches = image_size // patch_size - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - [height, width], - image_grid_pinpoints, - image_size, - ) - - unpadded_features, newline_features = get_unpadded_features( - height, width, npatches, num_patch_height, num_patch_width - ) - # The base patch covers the entire image - base_features = npatches**2 - return unpadded_features + newline_features + base_features - - -class VlmCausalLMBatch(CausalLMBatch): - pixel_values: Optional[List[torch.Tensor]] - pixel_attention_mask: Optional[List[torch.Tensor]] - image_sizes: Optional[List[Tuple[int, int]]] - aspect_ratio_ids: Optional[torch.Tensor] = None - aspect_ratio_mask: Optional[torch.Tensor] = None - cross_attention_mask: Optional[torch.Tensor] = None - prefilling: bool = True - token_idx: torch.Tensor = None - - def __init__( - self, - batch_id, - requests, - input_ids, - attention_mask, - position_ids, - past_key_values, - merged_kv_cache, - next_token_chooser, - top_n_tokens, - top_n_tokens_tensor, - input_length, - pixel_values: Optional[List[torch.Tensor]] = None, - pixel_attention_mask: Optional[List[torch.Tensor]] = None, - image_sizes: Optional[List[Tuple[int, int]]] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - prefilling: Optional[bool] = True, - ): - super().__init__( - batch_id=batch_id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=merged_kv_cache, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - - self.pixel_values = pixel_values - self.pixel_attention_mask = pixel_attention_mask - self.image_sizes = image_sizes - self.aspect_ratio_ids = aspect_ratio_ids - self.aspect_ratio_mask = aspect_ratio_mask - self.cross_attention_mask = cross_attention_mask - self.prefilling = prefilling - - @property - def token_idx(self): - if self.prefilling: - # no right padding for prefill - token_idx_scalar = self.attention_mask.shape[-1] - 1 - return torch.tensor(token_idx_scalar).to(self.attention_mask.device) - else: - token_idx_scalar = self.attention_mask.shape[-1] - self.right_padding - return torch.tensor(token_idx_scalar).to(self.attention_mask.device) - - def padding_process(self, pad_id: int): - # self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1) - right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1] - self.input_ids = torch.nn.functional.pad( - self.input_ids, (0, right_padding), value=pad_id - ) - self.attention_mask = torch.nn.functional.pad( - self.attention_mask, (0, right_padding), value=0 - ) - # if self.position_ids is not None: - # self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1 - if self.cross_attention_mask is not None: - self.cross_attention_mask = torch.nn.functional.pad( - self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0 - ) - if self.past is not None: - past_key_values_list = list(self.past_key_values) - for layer_id in range(len(self.past)): - past_key_value_list = list(self.past_key_values[layer_id]) - if layer_id not in CROSS_ATTENTION_LAYERS: - past_key_value_list[0] = torch.nn.functional.pad( - self.past_key_values[layer_id][0], - (0, 0, 0, right_padding), - value=0, - ) - past_key_value_list[1] = torch.nn.functional.pad( - self.past_key_values[layer_id][1], - (0, 0, 0, right_padding), - value=0, - ) - past_key_values_list[layer_id] = tuple(past_key_value_list) - self.past_key_values = tuple(past_key_values_list) - - self.prefilling = False - self.input_length = self.input_length - - @classmethod - def from_tokenized( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - batch_tokenized_inputs, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - - dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") - requests = [ - CausalLMRequest.from_pb(idx, req, tokenizer) - for idx, req in enumerate(pb.requests) - ] - - max_input_length = max(r.data.truncate for r in requests) - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - # TODO: Add support for sparse batches - top_n_tokens = [r.top_n_tokens for r in pb.requests] - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - tokenized_inputs = batch_tokenized_inputs - input_len = tokenized_inputs["input_ids"].shape[1] - - bucket_size = max_input_length - left_padding = max_input_length - input_len - if is_warmup is False: - rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) - bucket_size = rounded_seq_len - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None) - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - if cross_attention_mask is not None: - cross_attention_mask = torch.nn.functional.pad( - cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - - # New input length after left padding - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - cross_attention_mask = ( - cross_attention_mask.to(device) - if cross_attention_mask is not None - else None - ) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() - - return cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - cross_attention_mask=cross_attention_mask, - ) - - @classmethod - def batch_tokenized_inputs( - cls, - requests: Iterable[generate_pb2.Request], - tokenizer, - processor, - config, - is_warmup, - ): - image_inputs = {} - texts = [] - images = [] - batch_tokenized_inputs = {} - - for i, r in enumerate(requests): - # Each input is encoded into a list, where each element of this input list is either a string or a URL - curr_text = "" - curr_image = None - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - curr_text += chunk.text - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # TODO unsure about BOS - curr_image = image - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - - if image_text_replacement(config) not in curr_text: - if "" in curr_text: - curr_text = curr_text.replace( - "", image_text_replacement(config) - ) - else: - curr_text = image_text_replacement(config) + curr_text - - texts.append(curr_text) - if curr_image is not None: - if config.model_type == "mllama": - images.append([curr_image]) - else: - images.append(curr_image) - - if is_warmup is True: - images += [images[0]] * (len(texts) - len(images)) - - missing_inputs = 0 - dummy_images = None - if is_warmup is False: - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - missing_inputs = new_bs - len(requests) - if missing_inputs > 0: - dummy_inputs = [] - if len(texts) > 0: - dummy_inputs = [texts[0]] * missing_inputs - dummy_images = [images[0]] * missing_inputs - texts += dummy_inputs - images += dummy_images - - processor_output = processor( - images, - texts, - truncation=True, - max_length=r.truncate, - add_special_tokens=r.add_special_tokens, - return_tensors="pt", - padding_side="left", - padding="longest", - ) - if "input_ids" in processor_output: - batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]}) - if "attention_mask" in processor_output: - batch_tokenized_inputs.update( - {"attention_mask": processor_output["attention_mask"]} - ) - if "cross_attention_mask" in processor_output: - batch_tokenized_inputs.update( - {"cross_attention_mask": processor_output["cross_attention_mask"]} - ) - if "pixel_values" in processor_output: - image_inputs.update({"pixel_values": processor_output["pixel_values"]}) - if "pixel_attention_mask" in processor_output: - image_inputs.update( - {"pixel_attention_mask": processor_output["pixel_attention_mask"]} - ) - if "aspect_ratio_ids" in processor_output: - image_inputs.update( - {"aspect_ratio_ids": processor_output["aspect_ratio_ids"]} - ) - if "aspect_ratio_mask" in processor_output: - image_inputs.update( - {"aspect_ratio_mask": processor_output["aspect_ratio_mask"]} - ) - if "image_sizes" in processor_output: - image_inputs.update({"image_sizes": processor_output["image_sizes"]}) - - return batch_tokenized_inputs, image_inputs - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor, - config, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config, is_warmup - ) - batch = cls.from_tokenized( - pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup - ) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "aspect_ratio_ids" in image_inputs: - batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to( - device=device - ) - else: - batch.aspect_ratio_ids = None - if "aspect_ratio_mask" in image_inputs: - batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( - device=device - ) - else: - batch.aspect_ratio_mask = None - else: - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - batch.aspect_ratio_ids = None - batch.aspect_ratio_mask = None - batch.cross_attention_mask = None - - return batch - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, - batches: List["CausalLMBatch"], - pad_token_id: int = 0, - is_warmup: bool = False, - ) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id, is_warmup) - - @classmethod - def recombine( - cls, - batches: List["VlmCausalLMBatch"], - pad_token_id: int, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! Cannot recombine before prefill!") - # Used for padding - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - if not is_warmup: - new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) - - if len(batches) > 1: - scenario = "CONCAT" - elif batches[0].prefilling: - scenario = "SHIFT" - else: - return batches[0] - - dbg_trace( - scenario, - f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}", - ) - - if scenario == "SHIFT": - batch = batches[0] - batch.padding_process(pad_token_id) - return batch - - total_batch_size = 0 - max_input_length = 0 - for i, batch in enumerate(batches): - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.input_length) - # Batch attributes - requests = [] - input_lengths = [] - top_n_tokens = [] - parameters = [] - fsm_grammar_states = [] - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - top_n_tokens_tensor = None - cross_attention_mask = None - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - keep_indices = [] - for req in batch.requests: - keep_indices.append(req.idx) - - requests.extend(batch.requests) - parameters.extend([r.data.parameters for r in batch.requests]) - fsm_grammar_states.extend( - [batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices] - ) - input_lengths.extend([batch.input_length]) - top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices]) - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS)) - # # Copy to correct indices - - left_offset = max_input_length - batch.input_length - right_padding = MAX_TOTAL_TOKENS - max_input_length - input_ids[start_index:end_index, left_offset:-right_padding] = ( - batch.input_ids[keep_indices, : batch.input_length] - ) - - # Create padded tensor - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - new_bs, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[ - keep_indices - ] - - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (new_bs, MAX_TOTAL_TOKENS), - ) - - attention_mask[ - start_index:end_index, - left_offset:-right_padding, - ] = batch.attention_mask[ - keep_indices, - : batch.input_length, - ] - - if batch.cross_attention_mask is not None: - cross_attention_mask_shape = list(batch.cross_attention_mask.shape) - cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS - cross_attention_mask_shape[0] = new_bs - cross_attention_mask_shape = torch.Size(cross_attention_mask_shape) - if cross_attention_mask is None: - cross_attention_mask = batch.cross_attention_mask.new_zeros( - cross_attention_mask_shape, - ) - cross_attention_mask[ - start_index:end_index, - left_offset:-right_padding, - ] = batch.cross_attention_mask[ - keep_indices, - : batch.input_length, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((new_bs, 1)) - position_ids[start_index:end_index] = batch.position_ids[keep_indices, :] - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values, tuple): - batch.past_key_values = [ - [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:]) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - past_key_values = [] - for layer_id in range(len(batches[0].past_key_values)): - if layer_id in CROSS_ATTENTION_LAYERS: - padded_past_keys_shape = list( - batches[0].past_key_values[layer_id][0].shape - ) - padded_past_keys_shape[0] = new_bs - padded_past_keys_shape = torch.Size(padded_past_keys_shape) - else: - padded_past_keys_shape = ( - new_bs, - num_heads, - MAX_TOTAL_TOKENS, - head_dim, - ) - - padded_past_keys = first_past_kvs[layer_id][0].new_zeros( - padded_past_keys_shape - ) - padded_past_values = first_past_kvs[layer_id][1].new_zeros( - padded_past_keys_shape - ) - start_index = 0 - for batch in batches: - keep_indices = [] - for req in batch.requests: - keep_indices.append(req.idx) - - left_offset = max_input_length - batch.input_length - right_padding = MAX_TOTAL_TOKENS - max_input_length - past_keys = batch.past_key_values[layer_id][0] - past_values = batch.past_key_values[layer_id][1] - # Clear reference to the original tensor - batch.past_key_values[layer_id] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - if layer_id in CROSS_ATTENTION_LAYERS: - padded_past_keys[start_index:end_index, :, :, :] = past_keys[ - keep_indices, :, :, : - ] - padded_past_values[start_index:end_index, :, :, :] = past_values[ - keep_indices, :, :, : - ] - - else: - padded_past_keys[ - start_index:end_index, :, left_offset:-right_padding, : - ] = past_keys[keep_indices, :, : batch.input_length, :] - padded_past_values[ - start_index:end_index, :, left_offset:-right_padding, : - ] = past_values[keep_indices, :, : batch.input_length, :] - - start_index = end_index - - past_key_values.append(tuple([padded_past_keys, padded_past_values])) - past_key_values = tuple(past_key_values) - - batch_id = batches[0].batch_id - top_n_tokens.extend([-1] * (new_bs - total_batch_size)) - fsm_grammar_states.extend([-1] * (new_bs - total_batch_size)) - - for idx, req in enumerate(requests): - req.idx = idx - - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[0].next_token_chooser.dtype, - batches[0].next_token_chooser.device, - batches[0].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - pixel_values=None, - pixel_attention_mask=None, - image_sizes=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=cross_attention_mask, - prefilling=False, - ) - - -class VlmCausalLM(Model): - def __init__( - self, - model_class, - model_id: str, - *, - processor_class=AutoProcessor, - processor_kwargs=None, - batch_class=VlmCausalLMBatch, - revision, - quantize: Optional[str] = None, - dtype, - trust_remote_code: bool, - **kwargs, - ): - adapt_transformers_to_gaudi() - if processor_kwargs is None: - processor_kwargs = {} - self.processor = processor_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - **processor_kwargs, - ) - self.batch_class = batch_class - self.prev_bs = 0 - self.quantize = quantize - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - make_tokenizer_optional(tokenizer) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - # Get weight files - weight_files(model_id, revision=revision, extension=".safetensors") - - if world_size > 1: - os.environ.setdefault( - "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" - ) - model = self.get_deepspeed_model(model_class, model_id, dtype, revision) - model = hq_env.prepare_model_for_quantization(model) - else: - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained(model_id) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = model_class.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs, - ) - model = hq_env.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = ( - os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - ) - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true" - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace("TORCH COMPILE", "Torch compiling of model") - model.model = torch.compile( - model.model, - backend="hpu_backend", - options={"keep_input_mutations": True}, - ) - - model = hq_env.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - if isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" - ) - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - self.kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in ["llava_next"]: - self.kwargs["attn_softmax_bf16"] = True - self.kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true": - self.kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true": - self.kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - if model.config.model_type == "mllama": - global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS - CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers - BASE_IMAGE_TOKENS = 0 - - super(VlmCausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = ( - int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_steps = ( - int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes, - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - @property - def batch_type(self) -> Type[VlmCausalLMBatch]: - return self.batch_class - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) - - def get_deepspeed_model( - self, - model_class, - model_id: str, - dtype: torch.dtype, - revision: Optional[str] = None, - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = {"revision": revision} - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format( - world_size, rank, local_rank - ) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = model_class.from_config(config, torch_dtype=dtype) - else: - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = model_class.from_pretrained( - model_id, torch_dtype=dtype, **model_kwargs - ) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - ds_inference_kwargs["injection_policy"] = get_ds_injection_policy( - model.language_model.config - ) - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [ - str(f) - for f in weight_files( - model_id, revision=revision, extension=".safetensors" - ) - ] - data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, checkpoints_json) - checkpoints_json.flush() - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return {"type": rope_scaling, "factor": float(rope_factor)} - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode( - all_input_ids[read_offset:], skip_special_tokens=False - ) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) - - def forward( - self, - batch: VlmCausalLMBatch, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": batch.input_ids, - "attention_mask": batch.attention_mask, - "past_key_values": batch.past_key_values, - "token_idx": batch.token_idx, - "pixel_values": batch.pixel_values, - } - - if self.model.config.model_type == "mllama": - kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids - kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask - kwargs["cross_attention_mask"] = batch.cross_attention_mask - else: - kwargs["image_sizes"] = batch.image_sizes - - hpu_kwargs = {} - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama": - hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 - - if self.has_position_ids: - kwargs["position_ids"] = batch.position_ids - if bypass_hpu_graph is not None: - hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - - kwargs.update(self.kwargs) - model_inputs = self.model.prepare_inputs_for_generation(**kwargs) - - if batch.past_key_values is not None: - return self.model.forward(**model_inputs, **hpu_kwargs) - else: - outputs = self.model.forward(**model_inputs, **hpu_kwargs) - return outputs.logits, outputs.past_key_values - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: list[VlmCausalLMBatch], is_warmup: bool = False - ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]: - - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = ( - batch.attention_mask.shape[-1] - batch.right_padding - ) - token_idx = torch.tensor(token_idx_scalar).to(self.device) - - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, - logits[:, input_length - 1 : input_length, :].squeeze(-2), - self.speculate, - ) - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append( - { - "next_token_ids": next_token_ids, - "next_token_logprobs": next_token_logprobs, - } - ) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append( - { - "req": req, - "prev_req_idx": req.idx, - "batch_id": batch_id, - "seed": batch.next_token_chooser.seeds[req_idx], - "do_sample": batch.next_token_chooser.do_sample[req_idx], - "top_n_tokens": batch.top_n_tokens[req_idx], - "top_token_ids": batch_top_token_ids[req_idx], - "top_token_logprobs": batch_top_token_logprobs[req_idx], - "grammar_state": batch.next_token_chooser.fsm_grammar_states[ - req.idx - ], - } - ) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # add cross-attn mask for new token - if batch.cross_attention_mask is not None: - cross_attention_mask_prev = batch.cross_attention_mask - if token_idx is not None: - mask = cross_attention_mask_prev[ - :, token_idx - 2 : token_idx - 1, ... - ] - cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) - batch.cross_attention_mask = cross_attention_mask_prev - - # Adjust lengths - batch.input_length += 1 - # Update position_ids - if prefill: - batch.position_ids = ( - torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - ) - else: - batch.position_ids += 1 - # Update past key values - if prefill: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate( - batches, self.tokenizer.pad_token_id, is_warmup - ) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = self.batch_type.recombine( - [batch], self.tokenizer.pad_token_id, is_warmup - ) - - scenario = "PREFILL" if prefill else "GENERATE" - if ( - self.enable_hpu_graph - and self.limit_hpu_graph - and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - != self.prev_bs - ): - self.model.clear_cache() - self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - dbg_trace( - scenario, - f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", - ) - # assert batch.right_padding > 0, 'No more room for next token!' - - # Execute batch - if prefill: - # no right padding for prefill - # token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - # token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) - batch.logits = self.forward( - batch, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.aspect_ratio_ids is not None: - batch.aspect_ratio_ids = None - if batch.aspect_ratio_mask is not None: - batch.aspect_ratio_mask = None - - htorch.core.mark_step() - - start_decode = time.time_ns() - - # Stage 3. Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch["next_token_logprobs"] = prev_batch[ - "next_token_logprobs" - ].tolist() - prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data["req"] - i = req_data["prev_req_idx"] - prev_batch_id = req_data["batch_id"] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] - next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data["do_sample"] - seed = req_data["seed"] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data["top_n_tokens"] - top_token_ids = req_data["top_token_ids"] - top_token_logprobs = req_data["top_token_logprobs"] - grammar_state = req_data["grammar_state"] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if ( - is_tokenizer_transparent(self.tokenizer) - and len(stopping_criteria.stop_sequence_criterias) == 0 - ): - next_token_text = "" - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[ - new_input_length - - stopping_criteria.current_tokens : new_input_length, - 0, - ] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - [next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if ( - self.step - > self.profiling_wait_steps - + self.profiling_warmup_steps - + self.profiling_steps - ): - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def batch_from_pb(self, batch, is_warmup): - return self.batch_type.from_pb_processor( - batch, - self.tokenizer, - self.processor, - self.model.config, - self.dtype, - self.device, - is_warmup, - ) - - def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): - batch = copy.deepcopy(request.batch) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_from_pb(batch, is_warmup) - - def warmup( - self, request: generate_pb2.WarmupRequest - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - global MAX_TOTAL_TOKENS - MAX_TOTAL_TOKENS = request.max_total_tokens - batch = self.batch_from_pb(request.batch, is_warmup=True) - max_input_tokens = request.max_input_tokens - max_prefill_batch_size = batch.input_ids.shape[0] - max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") - if max_batch_size_str is not None: - MAX_BATCH_SIZE = int(max_batch_size_str) - else: - raise ValueError("MAX_BATCH_SIZE is not set") - - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) - except Exception: - raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - PREFILL_WARMUP_BATCH_SIZE_LIST = [] - batch_size = 1 - while batch_size <= max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) - - if self.model.config.model_type == "mllama": - seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF - else: - seq_len = BASE_IMAGE_TOKENS - - PREFILL_WARMUP_SEQLEN_LIST = [] - i = 0 - while seq_len <= max_input_tokens: - PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i) - i += 1 - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_tokens: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_tokens) - - # Prefill and decode warmup - DECODE_WARMUP_BATCH_SIZE_LIST = [] - prefill_batch = None - decode_batch = None - try: - for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST: - for seq_len in PREFILL_WARMUP_SEQLEN_LIST: - batch = self.generate_warmup_batch( - request, seq_len, batch_size, is_warmup=True - ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) - assert prefill_batch is not None - _, decode_batch, _ = self.generate_token( - [prefill_batch], is_warmup=True - ) - - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - - except Exception: - raise RuntimeError( - f"Not enough memory to handle following prefill and decode warmup." - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing prefill and decode warmup successfully.\n" - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats} " - ) - - max_decode_batch_size = MAX_BATCH_SIZE - batch_size = max_prefill_batch_size * 2 - # Decode warmup with bigger batch_size - try: - if ( - DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size - and batch_size <= max_decode_batch_size - ): - batches = [] - while batch_size <= max_decode_batch_size: - for i in range(int(batch_size / max_prefill_batch_size)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0] - 1, - max_prefill_batch_size, - is_warmup=True, - ) - _, prefill_batch, _ = self.generate_token( - [batch], is_warmup=True - ) - batches.append(prefill_batch) - - _, decode_batch, _ = self.generate_token(batches, is_warmup=True) - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - batches.clear() - - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: - max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2 - batch_size = max_decode_batch_size - for i in range(int(max_decode_batch_size / 2)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0] - 1, - 2, - is_warmup=True, - ) - _, prefill_batch, _ = self.generate_token( - [batch], is_warmup=True - ) - batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup=True) - DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) - - except Exception: - raise RuntimeError( - f"Not enough memory to handle batch_size({batch_size}) decode warmup." - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"max_decode_batch_size is {max_decode_batch_size}" - f"You need to decrease env `MAX_BATCH_SIZE` or '--max_batch_size'" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing decode warmup successfully.\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats}" - ) - - max_supported_total_tokens = MAX_BATCH_SIZE * MAX_TOTAL_TOKENS - max_input_tokens = max_input_tokens - max_total_tokens = MAX_TOTAL_TOKENS - - return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh index d787ea8e..a5c3f5e1 100644 --- a/backends/gaudi/tgi-entrypoint.sh +++ b/backends/gaudi/tgi-entrypoint.sh @@ -7,13 +7,5 @@ if [[ "$*" == *"--sharded true"* ]]; then echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding' export PT_HPU_ENABLE_LAZY_COLLECTIVES=1 fi -# Check if ATTENTION environment variable is set to paged -if [[ "$ATTENTION" == "paged" ]]; then - # Check if Llama-4 is in the command line arguments - if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then - echo 'ATTENTION=paged and Llama-4 or Qwen3 detected' - pip install git+https://github.com/huggingface/transformers.git@29338949 - fi -fi text-generation-launcher $@ diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index d9056e41..cd4ee290 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -27,10 +27,6 @@ impl Env { docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), } } - - pub fn should_start_a_single_hpu_shard(&self) -> bool { - self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged") - } } impl fmt::Display for Env { diff --git a/launcher/src/main.rs b/launcher/src/main.rs index ee80eb00..c727623c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1590,11 +1590,6 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() { - tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); - break; - } - let model_id = args.model_id.clone(); let revision = args.revision.clone(); let uds_path = args.shard_uds_path.clone(); @@ -1670,10 +1665,6 @@ fn spawn_shards( if shard_ready == num_shard { break; } - if env_runtime::Env::new().should_start_a_single_hpu_shard() { - tracing::info!("HPU detected, shard is ready"); - break; - } } Err(TryRecvError::Empty) => { sleep(Duration::from_millis(100));