Gaudi: Sync TGI with the latest changes from the TGI-Gaudi fork (#3117)

feat(gaudi): add all the changes from tgi-gaudi fork up to PR #289
This commit is contained in:
Baptiste Colle 2025-03-18 09:45:52 +01:00 committed by GitHub
parent 095775e05c
commit 8c2c348f3c
10 changed files with 914 additions and 1243 deletions

View File

@@ -1,6 +1,6 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.19.0
-ARG PYTORCH_VERSION=2.5.1
+ARG HABANA_VERSION=1.20.0
+ARG PYTORCH_VERSION=2.6.0

 # Rust builder
 FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
@@ -92,7 +92,6 @@ RUN cd server && \
     make gen-server && \
     pip install --no-deps -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
-    pip install outlines~=0.0.34 && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir

View File

@@ -2,8 +2,8 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := "${mkfile_dir}/../.."

-HABANA_VERSION := 1.19.0
-PYTORCH_VERSION := 2.5.1
+HABANA_VERSION := 1.20.0
+PYTORCH_VERSION := 2.6.0

 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

View File

@@ -22,7 +22,7 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 peft = "^0.10"
-optimum-habana = "1.15.0"
+optimum-habana = "1.16.0"
 transformers = "4.45.2"
 numpy = "1.26.4"
 accelerate = "0.33.0"

View File

@@ -46,7 +46,7 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
 pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -87,3 +87,18 @@ wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
+outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13"
+interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
+cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13"
+diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
+numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
+llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
+nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.10.6 ; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13"
+referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
+rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@@ -17,16 +17,14 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+from text_generation_server.models.custom_modeling.mllama import (
+    MllamaForConditionalGeneration,
+)
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )
-# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-# from text_generation_server.models.custom_modeling.mllama import (
-#     MllamaForConditionalGeneration,
-# )
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -39,6 +37,7 @@ from text_generation_server.adapters.lora import LoraWeights
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

+SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))

 # Disable gradients
 torch.set_grad_enabled(False)
@@ -55,6 +54,8 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     adapt_transformers_to_gaudi()
+    if SDP_ON_BF16 == 1:
+        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)

     if speculate is not None:
         set_speculate(speculate)
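The `SDP_ON_BF16` switch is evaluated once at module import, so it has to be in the environment before the server process starts. A minimal sketch of the semantics (editorial, not part of the diff):

```python
import os

# Must be exported before text_generation_server is imported; "1" allows
# PyTorch's math backend for scaled_dot_product_attention to accumulate in
# reduced precision (bf16/fp16) instead of fp32, trading accuracy for speed.
os.environ["SDP_ON_BF16"] = "1"
```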
@@ -199,6 +200,17 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

+    if model_type == "mllama":
+        return VlmCausalLM(
+            model_class=MllamaForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=None,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,

View File

@@ -704,6 +704,9 @@ class CausalLM(Model):
         htorch.core.hpu_set_env()

         if world_size > 1:
+            os.environ.setdefault(
+                "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
+            )
             model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
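Because the new default goes through `os.environ.setdefault`, it only applies when the variable is unset, so an operator can still opt out explicitly. A runnable sketch of that behavior:

```python
import os

# A value exported before launch wins over the new default:
os.environ["DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API"] = "0"
os.environ.setdefault("DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1")
assert os.environ["DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API"] == "0"
```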

View File

@@ -14,10 +14,11 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""

-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
 import torch.utils.checkpoint
+import numpy as np

 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
@@ -49,6 +50,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size

+
+
+# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+    """
+    Calculate the number of patches after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
+            The size of the input image in the format (height, width).
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        int: the number of patches
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple, otherwise the calculation will be wrong
+    if not isinstance(image_size, (list, tuple)):
+        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+            raise TypeError(
+                f"image_size invalid type {type(image_size)} with value {image_size}"
+            )
+        image_size = image_size.tolist()
+
+    best_resolution = select_best_resolution(image_size, grid_pinpoints)
+    height, width = best_resolution
+    num_patches = 0
+    # consider changing to ceil(height/patch_size) * ceil(width/patch_size) + 1
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            num_patches += 1
+    # add the base patch
+    num_patches += 1
+    return num_patches
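For intuition, the double loop walks the best resolution in `patch_size` steps and then adds one base patch. A minimal sketch of the arithmetic, assuming a best resolution of (672, 672) with `patch_size=336`, which matches the "base image + 4 patches" layout referenced later in this diff:

```python
# Sketch of the counting done by image_size_to_num_patches (assumed values).
def count_patches(height: int, width: int, patch_size: int) -> int:
    grid = sum(
        1
        for _ in range(0, height, patch_size)
        for _ in range(0, width, patch_size)
    )
    return grid + 1  # plus the base (downscaled whole-image) patch

assert count_patches(672, 672, 336) == 5  # base image + 4 grid patches
```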
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
     def _merge_input_ids_with_image_features(
@@ -128,6 +169,76 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         return outputs

+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_sizes: torch.Tensor,
+        vision_feature_layer: Union[int, List[int]],
+        vision_feature_select_strategy: str,
+    ):
+        """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`):
+                The tensors corresponding to the input images.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
+                Actual image size of each image (H, W).
+            vision_feature_layer (`Union[int, List[int]]`):
+                The index of the layer to select the vision feature. If multiple indices are provided,
+                the vision features of the corresponding indices will be concatenated to form the
+                vision features.
+            vision_feature_select_strategy (`str`):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Can be one of `"default"` or `"full"`.
+
+        Returns:
+            image_features (`List[torch.Tensor]`): List of image feature tensors, each containing all the
+            visual features of all patches, of shape `(num_patches, image_length, embed_dim)`.
+        """
+        # ! infer image_num_patches from image_sizes
+        image_num_patches = [
+            image_size_to_num_patches(
+                image_size=imsize,
+                grid_pinpoints=self.config.image_grid_pinpoints,
+                patch_size=self.config.vision_config.image_size,
+            )
+            for imsize in image_sizes
+        ]
+        if pixel_values.dim() == 5:
+            # stacked if input is (batch_size, num_patches, num_channels, height, width)
+            _pixel_values_list = [
+                pix_val[:num_patch]
+                for pix_val, num_patch in zip(pixel_values, image_num_patches)
+            ]
+            pixel_values = torch.cat(_pixel_values_list, dim=0)
+        elif pixel_values.dim() != 4:
+            # otherwise it has to be stacked from a list of (num_patches, num_channels, height, width)
+            raise ValueError(
+                f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
+            )
+        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+        # If we have one vision feature layer, return the corresponding hidden states,
+        # otherwise, select the hidden states of each feature layer and concatenate them
+        if isinstance(vision_feature_layer, int):
+            selected_image_feature = image_features.hidden_states[vision_feature_layer]
+        else:
+            hs_pool = [
+                image_features.hidden_states[layer_idx]
+                for layer_idx in vision_feature_layer
+            ]
+            selected_image_feature = torch.cat(hs_pool, dim=-1)
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = torch.split(image_features, image_num_patches, dim=0)
+        return image_features
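The 5D branch above trims each image's padded patch slots before everything is handed to the vision tower as one flat 4D batch. A runnable sketch with synthetic shapes (the per-image patch counts are assumed, standing in for `image_size_to_num_patches`):

```python
import torch

# Synthetic example of the 5D -> 4D flattening performed above.
pixel_values = torch.randn(2, 5, 3, 336, 336)  # (batch, num_patches, C, H, W)
image_num_patches = [5, 3]  # assumed per-image patch counts

pixel_values_list = [
    pix_val[:num_patch]
    for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
flat = torch.cat(pixel_values_list, dim=0)
assert flat.shape == (8, 3, 336, 336)  # only real patches reach the vision tower
```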
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -184,35 +295,12 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             # 1. Extract the input embeddings
             inputs_embeds = self.get_input_embeddings()(input_ids)

             # 2. Merge text and images
-            batch_size, num_patches, num_channels, height, width = (
-                pixel_values.shape
-            )
-            reshaped_pixel_values = pixel_values.view(
-                batch_size * num_patches, num_channels, height, width
-            )
-            image_features = self.vision_tower(
-                reshaped_pixel_values,
-                output_hidden_states=True,
-                use_flash_attention=use_flash_attention,
-                flash_attention_recompute=flash_attention_recompute,
-            )
-            selected_image_feature = image_features.hidden_states[
-                vision_feature_layer
-            ]
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            image_features = self.multi_modal_projector(selected_image_feature)
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [image.shape[0] for image in pixel_values]
-            image_features = torch.split(image_features, split_sizes, dim=0)
+            image_features = self.get_image_features(
+                pixel_values,
+                image_sizes,
+                vision_feature_layer=vision_feature_layer,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+            )

             # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
             height = width = (
@@ -266,13 +354,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                         (image_feature, self.image_newline[None]), dim=0
                     )
                 new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
+            image_features = torch.cat(new_image_features, dim=0)
             inputs_embeds = self._merge_input_ids_with_image_features(
                 inputs_embeds, image_features, input_ids
             )
-            self.image_offset = (
-                image_features.shape[1] - 1
-            )  # image_token has occupied 1 token position.
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
         elif past_key_values is not None:
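The switch from `torch.stack` to `torch.cat` in the hunk above follows from any-resolution support: each image may now contribute a different number of patch features, so the per-image tensors no longer share a common first dimension. A runnable sketch with synthetic shapes:

```python
import torch

features_img0 = torch.randn(5, 32)  # base patch + 4 grid patches
features_img1 = torch.randn(3, 32)  # smaller best resolution, fewer patches

# torch.stack requires identical shapes and would raise a RuntimeError here;
# torch.cat flattens the variable-length features into one sequence instead.
merged = torch.cat([features_img0, features_img1], dim=0)
assert merged.shape == (8, 32)
```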
@@ -282,12 +367,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             # Retrieve the first layer to inspect the logits and mask out the hidden states
             # that are set to 0
             first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

             # Get the target length
             past_length = first_layer_past_key_value.shape[-1]
             extended_attention_mask = torch.ones(

View File

@@ -291,6 +291,8 @@ The following table contains the environment variables that can be used to configure
 Contributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md).

 **Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder.

+### Building the Docker Image from Source
+To build the Docker image from source:
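The command itself falls outside this hunk; as a sketch, assuming the `image` target from the `backends/gaudi` Makefile updated earlier in this commit:

```bash
# Assumed invocation, run from the repository root; the `image` target and
# the version variables come from the backends/gaudi Makefile in this diff.
make -C backends/gaudi image
```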