From 48df6183e8e15c433ae71b4885ddfcb7b7f603b1 Mon Sep 17 00:00:00 2001
From: baptiste
Date: Mon, 17 Mar 2025 11:01:53 +0000
Subject: [PATCH] feat(gaudi): add all the changes from tgi-gaudi fork up to PR #289

---
 Dockerfile_gaudi                                   |    5 +-
 backends/gaudi/Makefile                            |    4 +-
 backends/gaudi/server/pyproject.toml               |    2 +-
 backends/gaudi/server/requirements.txt             |   17 +-
 .../text_generation_server/models/__init__.py      |   22 +-
 .../models/causal_lm.py                            |    3 +
 .../models/custom_modeling/llava_next.py           |  153 ++-
 .../models/custom_modeling/mllama.py               | 1211 ++++-------------
 .../models/vlm_causal_lm.py                        |  738 ++++++----
 docs/source/backends/gaudi.mdx                     |    2 +
 10 files changed, 914 insertions(+), 1243 deletions(-)

diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index c814e912..9009f95b 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -1,6 +1,6 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.19.0
-ARG PYTORCH_VERSION=2.5.1
+ARG HABANA_VERSION=1.20.0
+ARG PYTORCH_VERSION=2.6.0
 
 # Rust builder
 FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
@@ -92,7 +92,6 @@ RUN cd server && \
     make gen-server && \
     pip install --no-deps -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
-    pip install outlines~=0.0.34 && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
index ce3be25d..6e38c19e 100644
--- a/backends/gaudi/Makefile
+++ b/backends/gaudi/Makefile
@@ -2,8 +2,8 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := "${mkfile_dir}/../.."
 
-HABANA_VERSION := 1.19.0
-PYTORCH_VERSION := 2.5.1
+HABANA_VERSION := 1.20.0
+PYTORCH_VERSION := 2.6.0
 
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml
index c61ac030..b38f4562 100644
--- a/backends/gaudi/server/pyproject.toml
+++ b/backends/gaudi/server/pyproject.toml
@@ -22,7 +22,7 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 peft = "^0.10"
-optimum-habana = "1.15.0"
+optimum-habana = "1.16.0"
 transformers = "4.45.2"
 numpy = "1.26.4"
 accelerate = "0.33.0"
diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt
index 07414490..b71d29d9 100644
--- a/backends/gaudi/server/requirements.txt
+++ b/backends/gaudi/server/requirements.txt
@@ -46,7 +46,7 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
 pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -87,3 +87,18 @@ wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.16.0 ; 
python_version >= "3.9" and python_version < "3.13" zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" +outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13" +interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13" +lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13" +cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13" +diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13" +numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13" +llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13" +jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13" +annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13" +jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13" +nest-asyncio==1.6.0; python_version >= "3.9" and python_version < "3.13" +pydantic==2.10.6; python_version >= "3.9" and python_version < "3.13" +pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13" +referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13" +rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 651b71ec..346016c2 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -17,16 +17,14 @@ from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.starcoder import StarCoder from text_generation_server.models.vlm_causal_lm import VlmCausalLM - -# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM +from text_generation_server.models.custom_modeling.mllama import ( + MllamaForConditionalGeneration, +) from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch -# from text_generation_server.models.custom_modeling.mllama import ( -# MllamaForConditionalGeneration, -# ) from text_generation_server.utils.adapter import ( AdapterParameters, build_layer_weight_lookup, @@ -39,6 +37,7 @@ from text_generation_server.adapters.lora import LoraWeights from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0)) # Disable gradients torch.set_grad_enabled(False) @@ -55,6 +54,8 @@ def get_model( max_input_tokens: int, ) -> Model: adapt_transformers_to_gaudi() + if SDP_ON_BF16 == 1: + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) if speculate is not None: set_speculate(speculate) @@ -199,6 +200,17 @@ def get_model( trust_remote_code=trust_remote_code, ) + if model_type == "mllama": + return VlmCausalLM( + model_class=MllamaForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=None, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( model_id, diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py index 8fda0517..776c109f 100644 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ 
b/backends/gaudi/server/text_generation_server/models/causal_lm.py @@ -704,6 +704,9 @@ class CausalLM(Model): htorch.core.hpu_set_env() if world_size > 1: + os.environ.setdefault( + "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" + ) model = self.get_deepspeed_model(model_id, dtype, revision) model = hq_env.prepare_model_for_quantization(model) else: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py index f98dab91..70449f6b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py @@ -14,10 +14,11 @@ # limitations under the License. """ PyTorch Llava-NeXT model.""" -from typing import List, Optional +from typing import List, Optional, Union import torch import torch.utils.checkpoint +import numpy as np from transformers.models.llava_next.modeling_llava_next import ( unpad_image, @@ -49,6 +50,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size +# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79 +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """ + Calculate the number of patches after the preprocessing for images of any resolution. + + Args: + image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): + The size of the input image in the format (height, width). ? + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + int: the number of patches + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type {type(image_size)} with value {image_size}" + ) + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): def _merge_input_ids_with_image_features( @@ -128,6 +169,76 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): return outputs + # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479 + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. 
+ + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] + for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError( + f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions" + ) + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_features.hidden_states[vision_feature_layer] + else: + hs_pool = [ + image_features.hidden_states[layer_idx] + for layer_idx in vision_feature_layer + ] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + def prepare_inputs_for_generation( self, input_ids, @@ -184,35 +295,12 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): # 1. Extract the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) # 2. 
Merge text and images - batch_size, num_patches, num_channels, height, width = ( - pixel_values.shape + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, ) - reshaped_pixel_values = pixel_values.view( - batch_size * num_patches, num_channels, height, width - ) - image_features = self.vision_tower( - reshaped_pixel_values, - output_hidden_states=True, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - selected_image_feature = image_features.hidden_states[ - vision_feature_layer - ] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [image.shape[0] for image in pixel_values] - image_features = torch.split(image_features, split_sizes, dim=0) # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" height = width = ( @@ -266,13 +354,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): (image_feature, self.image_newline[None]), dim=0 ) new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) + image_features = torch.cat(new_image_features, dim=0) inputs_embeds = self._merge_input_ids_with_image_features( inputs_embeds, image_features, input_ids ) - self.image_offset = ( - image_features.shape[1] - 1 - ) # image_token has occupied 1 token position. # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of # generation with cache elif past_key_values is not None: @@ -282,12 +367,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 batch_index, non_attended_tokens = torch.where( first_layer_past_key_value.float().sum(-2) == 0 ) - # Get the target length past_length = first_layer_past_key_value.shape[-1] extended_attention_mask = torch.ones( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py index 73536bd6..6ba0ffff 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py @@ -14,982 +14,279 @@ # limitations under the License. 
"""PyTorch Mllama model.""" -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union import torch import torch.utils.checkpoint -from torch import nn -import flash_attn_2_cuda -from transformers.activations import ACT2FN -import torch.nn.functional as F - -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - FastLinear, -) -from text_generation_server.layers.attention import ( - Seqlen, -) -from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - FlashLlamaForCausalLM, +from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration +from optimum.habana.transformers.models.mllama.modeling_mllama import ( + _prepare_cross_attention_mask, ) +from transformers.modeling_outputs import CausalLMOutputWithPast -def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, -) -> torch.Tensor: - # Expand aspect ratio mask to target_length - batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) - attention_mask = attention_mask.repeat(1, 1, target_length, 1) - - # Mask padding patches - pad_patches = target_length - num_patches - attention_mask[:, :, -pad_patches:] = 0 - - # Invert the mask (0 -> 1, 1 -> 0) - attention_mask = 1 - attention_mask - - # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) - attention_mask = attention_mask.reshape( - batch_size, max_num_tiles * target_length, 1 - ) - attention_mask = ( - attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min - ) - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - - -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device, - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange( - target_length, device=device - ) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = ( - causal_mask.clone() - ) # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = ( - causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[ - :, :, :, :mask_length - ].masked_fill(padding_mask, min_dtype) - - return causal_mask - - -def _prepare_cross_attention_mask( - cross_attention_mask: torch.Tensor, - num_vision_tokens: int, - dtype: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - # reshape so it can be used by attn module - batch_size, text_total_length, *_ = cross_attention_mask.shape - cross_attention_mask = cross_attention_mask.repeat_interleave( - num_vision_tokens, dim=3 - ) - cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) - cross_attention_mask = cross_attention_mask.unsqueeze(1) - - # invert the mask - inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min - ) - - # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value) - .any(dim=-1) - .type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - return cross_attention_mask, full_text_row_masked_out_mask - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision -class MllamaVisionMLP(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = TensorParallelColumnLinear.load( - prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True - ) - self.fc2 = TensorParallelRowLinear.load( - prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class MllamaVisionSdpaAttention(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - - self.embed_dim = config.hidden_size - self.head_dim = config.hidden_size // config.attention_heads - self.num_heads = config.attention_heads // weights.process_group.size() - - self.qkv_proj = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) - self.o_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ) +class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration): def forward( self, - hidden_state: torch.Tensor, 
+ input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv = self.qkv_proj(hidden_state) - query, key, value = qkv.split( - [ - self.head_dim * self.num_heads, - self.head_dim * self.num_heads, - self.head_dim * self.num_heads, - ], - dim=2, - ) - - batch_size, q_seq_len, _ = query.shape - _, kv_seq_len, _ = key.shape - - query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) - key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) - - output = self.o_proj(attn_output) - return output - - -class MllamaVisionEncoderLayer(nn.Module): - def __init__(self, *, prefix, config, weights, is_gated: bool): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_attention_heads = config.attention_heads - self.is_gated = is_gated - self.intermediate_size = config.intermediate_size - - self.self_attn = MllamaVisionSdpaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights - ) - self.mlp = MllamaVisionMLP( - prefix=f"{prefix}.mlp", config=config, weights=weights - ) - - self.input_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 - ) - self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 - ) - - # there used to be an if else here, no code path - if is_gated: - self.gate_attn = nn.Parameter( - weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False - ) - self.gate_ffn = nn.Parameter( - weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - # Self Attention - residual = hidden_state - hidden_state = self.input_layernorm(hidden_state) - hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) - gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() - hidden_state = residual + gate_attn * hidden_state - - # Feed forward - residual = hidden_state - hidden_state = self.post_attention_layernorm(hidden_state) - hidden_state = self.mlp(hidden_state) - gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() - hidden_state = residual + gate_ffn * hidden_state - return hidden_state - - -class MllamaVisionEncoder(nn.Module): - def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int): - super().__init__() - self.config = config - self.layers = [ - MllamaVisionEncoderLayer( - prefix=f"{prefix}.layers.{i}", - config=config, - weights=weights, - is_gated=is_gated, - ) - for i in range(num_layers) - ] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - encoder_states = [hidden_states] - for encoder_layer in self.layers: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - ) - - hidden_states = layer_outputs - encoder_states.append(hidden_states) - - return hidden_states, encoder_states 
- - -class MllamaPrecomputedAspectRatioEmbedding(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.max_aspect_ratio_id = config.max_aspect_ratio_id - - self.embedding = TensorParallelEmbedding( - prefix=f"{prefix}.embedding", weights=weights - ) - self.gate = nn.Parameter( - weights.get_tensor(f"{prefix}.gate"), requires_grad=False - ) - - def forward( - self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor - ) -> torch.Tensor: - embeddings = self.embedding(aspect_ratio_ids) - embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) - - # Always gated. - embeddings = embeddings * self.gate.tanh() - - hidden_state = hidden_state + embeddings - return hidden_state - - -class MllamaPrecomputedPositionEmbedding(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 - self.hidden_size = config.hidden_size - self.scale = config.hidden_size**-0.5 - - self.gate = nn.Parameter( - weights.get_tensor(f"{prefix}.gate"), requires_grad=False - ) - - # position embedding - embedding = nn.Parameter( - weights.get_tensor(f"{prefix}.embedding"), requires_grad=False - ) - self.gated_position_embedding = (1 - self.gate.tanh()) * embedding - self.tile_embedding = TensorParallelEmbedding( - prefix=f"{prefix}.tile_embedding", weights=weights - ) - - def forward( - self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor - ) -> torch.Tensor: - # position embeddings - hidden_state = hidden_state + self.gated_position_embedding.view( - 1, 1, self.num_patches, self.hidden_size - ) - - # precomputed tile position embeddings - tile_position_embedding = self.tile_embedding(aspect_ratio_ids) - batch_size = hidden_state.shape[0] - tile_position_embedding = tile_position_embedding.reshape( - batch_size, self.max_num_tiles, self.num_patches, self.hidden_size - ) - gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding - hidden_state = hidden_state + gated_tile_position_embedding - - return hidden_state - - -class MllamaVisionModel(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.image_size = config.image_size - self.patch_size = config.patch_size - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.num_channels = config.num_channels - self.intermediate_layers_indices = config.intermediate_layers_indices - - self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 - self.scale = config.hidden_size**-0.5 - self.dtype = weights.dtype - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.hidden_size, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - bias=False, - ) - self.patch_embedding.weight = nn.Parameter( - weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False - ) - - self.class_embedding = nn.Parameter( - weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False - ) - - self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( - prefix=f"{prefix}.gated_positional_embedding", - config=config, - weights=weights, - ) - - self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( - prefix=f"{prefix}.pre_tile_positional_embedding", - 
config=config, - weights=weights, - ) - self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( - prefix=f"{prefix}.post_tile_positional_embedding", - config=config, - weights=weights, - ) - - ## layer norms - self.layernorm_pre = nn.LayerNorm.load( - prefix=f"{prefix}.layernorm_pre", - weights=weights, - # torch default - eps=1e-05, - ) - self.layernorm_post = nn.LayerNorm.load( - prefix=f"{prefix}.layernorm_post", - weights=weights, - # torch default - eps=1e-05, - ) - - ## encoders - self.transformer = MllamaVisionEncoder( - prefix=f"{prefix}.transformer", - config=config, - weights=weights, - is_gated=False, - num_layers=config.num_hidden_layers, - ) - self.global_transformer = MllamaVisionEncoder( - prefix=f"{prefix}.global_transformer", - config=config, - weights=weights, - is_gated=True, - num_layers=config.num_global_layers, - ) - - def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: - batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) - hidden_state = torch.cat([class_embedding, hidden_state], dim=1) - return hidden_state - - def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) - - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, height, width - ) - aspect_ratio_ids = aspect_ratio_ids.reshape( - batch_size * num_concurrent_media, -1 - ) - - # patch embedding - patch_embeds = self.patch_embedding(pixel_values) - hidden_state = patch_embeds.flatten(2).transpose(1, 2) - - # tile embeddings - _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, -1, dim - ) - hidden_state = self.pre_tile_positional_embedding( - hidden_state, aspect_ratio_ids - ) - - # apply cls token - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim - ) - hidden_state = self.apply_class_embedding(hidden_state) - num_patches += 1 - - # apply position embeddings - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches, dim - ) - hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) - - # apply encoder - hidden_state = self.layernorm_pre(hidden_state) - - # Compute the number of tokens to pad - num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 - # Compute padding tuple for pad function - padding = ( - 0, - 0, - 0, - num_padding_patches, - ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) - # Pad the tensor - hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) - slice_index = -num_padding_patches if num_padding_patches > 0 else None - - if attention_mask is not None: - attention_mask = attention_mask.reshape( - batch_size * num_concurrent_media, -1 - ) - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.dtype, - ) - - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) - hidden_state, all_intermediate_hidden_states = self.transformer( - hidden_state, - attention_mask=attention_mask, - ) - intermediate_hidden_states = [ - hidden_state - for idx, hidden_state in 
enumerate(all_intermediate_hidden_states) - if idx in self.intermediate_layers_indices - ] - intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) - - # apply global encoder - hidden_state = self.layernorm_post(hidden_state) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim, - ) - hidden_state = self.post_tile_positional_embedding( - hidden_state, aspect_ratio_ids - ) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles * (num_patches + num_padding_patches), - dim, - ) - hidden_state, _ = self.global_transformer( - hidden_state, attention_mask=attention_mask - ) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim, - ) - hidden_state = hidden_state[:, :, :slice_index] - - # adding intermediate layer outputs - hidden_state = hidden_state.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, dim - ) - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - -1, - ) - intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1 - ) - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) - return hidden_state - - -class MllamaTextCrossAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, *, prefix, config, weights, layer_idx): - super().__init__() - self.config = config - self.num_heads = self.config.num_attention_heads - self.num_key_value_heads = self.config.num_key_value_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.head_size = config.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.layer_idx = layer_idx - - self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - self.num_key_value_heads // weights.process_group.size() - ) - - self.q_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.q_proj", - weights=weights, - bias=False, - ) - self.k_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.k_proj", - weights=weights, - bias=False, - ) - self.v_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.v_proj", - weights=weights, - bias=False, - ) - self.o_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ) - - self.q_norm = MllamaTextRMSNorm.load( - prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps - ) - self.k_norm = MllamaTextRMSNorm.load( - prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps - ) - self.softmax_scale = self.head_size**-0.5 - - def forward( - self, - hidden_states: torch.Tensor, + cross_attention_mask: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, - # past_key_value=None, - # attention_mask: Optional[torch.Tensor] = None, - # cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - # hidden_states = hidden_states.unsqueeze(0) - # bsz, q_len, _ = hidden_states.size() - query_states = 
self.q_proj(hidden_states) - query_states = query_states.view(-1, self.num_heads, self.head_size) - query_states = self.q_norm(query_states) - - ( - cross_attention_states, - cu_seqlen_q, - cu_seqlen_k, - max_q, - max_k, - indices, - ) = cross_attention_states - - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) - value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) - key_states = self.k_norm(key_states) - - # key_states = key_states.repeat(1, self.num_key_value_groups, 1) - # value_states = value_states.repeat(1, self.num_key_value_groups, 1) - - causal = False - # logger.info( - # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" - # ) - attn_output = flash_attn_2_cuda.varlen_fwd( - query_states, - key_states, - value_states, - None, - cu_seqlen_q, - cu_seqlen_k, - None, - None, - None, # block_tables - None, - max_q, - max_k, - 0.0, - self.softmax_scale, - False, - causal, # Causal - -1, # window_size_left, - -1, - 0.0, # softcap - False, - None, - )[0] - attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) - - return attn_output - - -# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText -class MllamaTextMLP(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = True, + flash_attention_recompute: Optional[bool] = True, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 + The only differences are: + - add token_idx input + - add use_flash_attention and flash_attention_recompute + """ + full_text_row_masked_out_mask = kwargs.get( + "full_text_row_masked_out_mask", None ) - self.gate_up_proj = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], - weights=weights, - dim=0, - bias=False, - ) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - shape = x.shape - gate_up_states = self.gate_up_proj(x) - gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) - result = self.down_proj( - self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] - ) - return result - - -class FlashLlamaCrossLayer(torch.nn.Module): - """Cross-attention transformer block with tanh-gated attention and feedforward.""" - - def __init__(self, *, prefix, config, weights, index) -> None: - layer_idx = index - super().__init__() - self.cross_attn = MllamaTextCrossAttention( - 
prefix=f"{prefix}.cross_attn", - config=config, - weights=weights, - layer_idx=layer_idx, - ) - - self.input_layernorm = MllamaTextRMSNorm.load( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps - ) - self.cross_attn_attn_gate = torch.nn.Parameter( - weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False - ) - - self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.post_attention_layernorm = MllamaTextRMSNorm.load( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=config.rms_norm_eps, - ) - self.cross_attn_mlp_gate = torch.nn.Parameter( - weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False - ) - self.layer_idx = layer_idx - - def forward( - self, - hidden_states, - residual, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - seqlen, - max_s, - adapter_data, - cross_attention_states, # [ IB, ...] - ) -> Tuple[torch.Tensor, torch.Tensor]: - if cross_attention_states is None: - return hidden_states, residual - if residual is not None: - hidden_states += residual - - indices = cross_attention_states[-1] - out_hidden_states = hidden_states[:] - if len(indices) > 0: - assert max(indices) < hidden_states.shape[0] - hidden_states = hidden_states[indices] - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states = self.cross_attn( - hidden_states=hidden_states, - # attention_mask=cross_attention_mask, - cross_attention_states=cross_attention_states, - ) - hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - - out_hidden_states[indices] = hidden_states - hidden_states = out_hidden_states - - return hidden_states, None - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText -class MllamaTextRMSNorm(nn.Module): - def __init__(self, weight, eps): - super().__init__() - self.weight = weight - self.variance_epsilon = eps - - @classmethod - def load(cls, *, prefix, weights, eps): - weight = nn.Parameter( - weights.get_tensor(f"{prefix}.weight"), requires_grad=False - ) - return cls(weight=weight, eps=eps) - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MllamaForConditionalGeneration(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - config.vision_config.quantize = None - config.vision_config.speculator = config.speculator - config.text_config.quantize = config.quantize - config.text_config.speculator = config.speculator - config.text_config._attn_implementation = "sdpa" - self.hidden_size = config.text_config.hidden_size - self.vision_model = MllamaVisionModel( - prefix="vision_model", config=config.vision_config, weights=weights - ) - self.multi_modal_projector = FastLinear.load( - prefix="multi_modal_projector", config=config, weights=weights, bias=True - ) - self.text_model = FlashLlamaForCausalLM( - prefix="language_model", config=config.text_config, 
weights=weights - ) - self.config = config - self.dtype = weights.dtype - self.device = weights.device - - def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask): - if aspect_ratio_ids is None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # logger.info(f"PIxel values {pixel_values.shape}") - batch_size = pixel_values.shape[0] - vision_states = self.vision_model( - pixel_values, aspect_ratio_ids, aspect_ratio_mask - ) - cross_attention_states = self.multi_modal_projector(vision_states).reshape( - -1, vision_states.shape[-2], self.hidden_size - ) - _, _, h = cross_attention_states.shape - cross_attention_states = cross_attention_states.view(batch_size, -1, h) - # logger.info(f"cross {cross_attention_states.shape}") - return cross_attention_states - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor], - adapter_data: Optional[torch.Tensor] = None, - # XXX: Putting these as optional so that the cuda warmup calls can go through. - cross_attention_states: Optional[torch.Tensor] = None, - image_indices=None, - ): - if cross_attention_states is not None: - seqlen_q = len(image_indices) - n_images = cross_attention_states.shape[0] - seqlen_k = cross_attention_states.shape[1] - device = cross_attention_states.device - if cu_seqlen_prefill is not None: - offset = 0 - cu_q = [] - indices = [] - for index in image_indices: - cu_q.append(offset) - length = seqlen.input_lengths[index].item() - assert index < seqlen.cu_seqlen_q.shape[0] - input_ids_offset = seqlen.cu_seqlen_q[index] - indices.extend(range(input_ids_offset, input_ids_offset + length)) - offset += length - cu_q.append(offset) - cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32) - - assert max(indices) < input_ids.shape[0] - - cu_seqlen_k = ( - torch.arange( - n_images + 1, - device=device, - dtype=torch.int32, - ) - * seqlen_k - ) - max_q = cu_seqlen_q[-1].item() - max_k = seqlen_k - else: - cu_seqlen_q = torch.arange( - seqlen_q + 1, device=device, dtype=torch.int32 - ) - seqlen_k = cross_attention_states.shape[1] - n_images = cross_attention_states.shape[0] - cu_seqlen_k = ( - torch.arange( - n_images + 1, - device=device, - dtype=torch.int32, - ) - * seqlen_k - ) - max_q = seqlen_q - max_k = seqlen_k - indices = image_indices[:] - - cross_attention_states = ( - cross_attention_states, - cu_seqlen_q, - cu_seqlen_k, - max_q, - max_k, - indices, + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - outputs = self.text_model( + outputs = self.language_model( input_ids=input_ids, + attention_mask=attention_mask, position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + 
output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) + logits = outputs[0] + if not return_dict: + output = (logits,) + outputs[1:] + return output + return outputs + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + """ + Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 + The only differences are: + - add token_idx handling + - add bucket_internal handling + - add use_flash_attention and flash_attention_recompute + """ + + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + cross_attention_mask=cross_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + else: + use_flash_attention = kwargs.get("use_flash_attention", True) + flash_attention_recompute = kwargs.get("flash_attention_recompute", True) + position_ids = kwargs.get("position_ids", None) + output_attentions = kwargs.get("output_attentions", None) + output_hidden_states = kwargs.get("output_hidden_states", None) + return_dict = kwargs.get("return_dict", None) + labels = kwargs.get("labels", None) + cross_attention_states = kwargs.get("cross_attention_states", None) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + bucket_internal = kwargs.get("bucket_internal", None) + + if past_key_values is not None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif bucket_internal and token_idx is not None: + # for the 1st token we can slice the inputs till token idx for the fwd pass. + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + if cross_attention_mask is not None: + cross_attention_mask = cross_attention_mask[:, :token_idx, ...] 
+ + # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select( + position_ids, 1, token_idx - 1 + ) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone( + memory_format=torch.contiguous_format + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError( + "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" + ) + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + use_flash_attention=use_flash_attention, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector( + cross_attention_states + ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = ( + _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + token_idx=token_idx, + ) + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None: + if cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + :, :, cache_position + ] + elif past_key_values is not None: + if token_idx is not None: + cross_attention_mask = torch.index_select( + cross_attention_mask, -2, token_idx - 1 + ) + full_text_row_masked_out_mask = torch.index_select( + full_text_row_masked_out_mask, -2, token_idx - 1 + ) + else: + cross_attention_mask = cross_attention_mask[:, :, -1:] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + :, :, -1: + ] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format), + "inputs_embeds": None, + } + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + # keep cache_position implementation as None for HPU + cache_position = None + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "labels": labels, + "return_dict": kwargs.get("return_dict"), + "full_text_row_masked_out_mask": full_text_row_masked_out_mask, + "use_flash_attention": use_flash_attention, + "cross_attention_mask": cross_attention_mask, + "cross_attention_states": cross_attention_states, + "output_attentions": output_attentions, + "flash_attention_recompute": flash_attention_recompute, + } + ) + + return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index cef761b4..66e00171 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -8,7 +8,6 @@ from io import BytesIO from opentelemetry import trace from loguru import logger from typing import Iterable, Optional, Tuple, List, Type, Dict -import itertools import tempfile import copy from text_generation_server.models import Model @@ -19,7 +18,6 @@ from text_generation_server.models.causal_lm import ( CausalLMBatch, CausalLMRequest, remove_kv_cache_from_output, - biggest_single_chunk, ) from transformers.models.llava_next.modeling_llava_next import ( @@ -68,18 +66,19 @@ IDEFICS2_IMAGE_TOKEN = "" IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048)) MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -MAX_BATCH_SIZE = ( - int(os.environ.get("MAX_BATCH_SIZE")) - if os.environ.get("MAX_BATCH_SIZE") is not None - else None -) +max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") +if max_batch_size_str is not None: + MAX_BATCH_SIZE = int(max_batch_size_str) +else: + raise ValueError("MAX_BATCH_SIZE is not set") PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = [] DECODE_WARMUP_BATCH_SIZE_LIST = [] +CROSS_ATTENTION_LAYERS = [] def round_up(warmup_list: list, num): @@ -87,7 +86,7 @@ def round_up(warmup_list: list, num): for i in warmup_list: if num <= i: break - return i + return i if i > 0 else num def split(string) -> List[Dict[str, str]]: @@ -107,20 +106,17 @@ def split(string) -> List[Dict[str, str]]: return parts -def image_text_replacement(processor, image_input, config, image_id: int) -> str: +def image_text_replacement(config) -> str: if config.model_type == "idefics2": image_seq_len = 64 image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" - if processor.image_processor.do_image_splitting: - image_str *= 5 return image_str elif config.model_type == "llava_next": - height, width = image_input["image_sizes"][image_id] - num_features = get_number_of_features(height, width, config) - return "" * num_features - + 
return "" elif config.model_type == "paligemma": - return "" * config.text_config.num_image_tokens + return "" + elif config.model_type == "mllama": + return "<|image|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -192,6 +188,100 @@ class VlmCausalLMBatch(CausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] + aspect_ratio_ids: Optional[torch.Tensor] = None + aspect_ratio_mask: Optional[torch.Tensor] = None + cross_attention_mask: Optional[torch.Tensor] = None + prefilling: bool = True + token_idx: torch.Tensor = None + + def __init__( + self, + batch_id, + requests, + input_ids, + attention_mask, + position_ids, + past_key_values, + merged_kv_cache, + next_token_chooser, + top_n_tokens, + top_n_tokens_tensor, + input_length, + pixel_values: Optional[List[torch.Tensor]] = None, + pixel_attention_mask: Optional[List[torch.Tensor]] = None, + image_sizes: Optional[List[Tuple[int, int]]] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + prefilling: Optional[bool] = True, + ): + super().__init__( + batch_id=batch_id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + merged_kv_cache=merged_kv_cache, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_length, + ) + + self.pixel_values = pixel_values + self.pixel_attention_mask = pixel_attention_mask + self.image_sizes = image_sizes + self.aspect_ratio_ids = aspect_ratio_ids + self.aspect_ratio_mask = aspect_ratio_mask + self.cross_attention_mask = cross_attention_mask + self.prefilling = prefilling + + @property + def token_idx(self): + if self.prefilling: + # no right padding for prefill + token_idx_scalar = self.attention_mask.shape[-1] - 1 + return torch.tensor(token_idx_scalar).to(self.attention_mask.device) + else: + token_idx_scalar = self.attention_mask.shape[-1] - self.right_padding + return torch.tensor(token_idx_scalar).to(self.attention_mask.device) + + def padding_process(self, pad_id: int): + # self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1) + right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1] + self.input_ids = torch.nn.functional.pad( + self.input_ids, (0, right_padding), value=pad_id + ) + self.attention_mask = torch.nn.functional.pad( + self.attention_mask, (0, right_padding), value=0 + ) + # if self.position_ids is not None: + # self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1 + if self.cross_attention_mask is not None: + self.cross_attention_mask = torch.nn.functional.pad( + self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0 + ) + if self.past is not None: + past_key_values_list = list(self.past_key_values) + for layer_id in range(len(self.past)): + past_key_value_list = list(self.past_key_values[layer_id]) + if layer_id not in CROSS_ATTENTION_LAYERS: + past_key_value_list[0] = torch.nn.functional.pad( + self.past_key_values[layer_id][0], + (0, 0, 0, right_padding), + value=0, + ) + past_key_value_list[1] = torch.nn.functional.pad( + self.past_key_values[layer_id][1], + (0, 0, 0, right_padding), + value=0, + ) + past_key_values_list[layer_id] = tuple(past_key_value_list) + self.past_key_values = 
tuple(past_key_values_list) + + self.prefilling = False + self.input_length = self.input_length @classmethod def from_tokenized( @@ -239,23 +329,23 @@ class VlmCausalLMBatch(CausalLMBatch): bucket_size = max_input_length left_padding = max_input_length - input_len if is_warmup is False: - if input_len < max_input_length: - rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len + rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) + bucket_size = rounded_seq_len - 1 + left_padding = bucket_size - input_len input_ids = tokenized_inputs["input_ids"] attention_mask = tokenized_inputs["attention_mask"] + cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None) # Allocate space for first token - if left_padding > 0: - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + if cross_attention_mask is not None: + cross_attention_mask = torch.nn.functional.pad( + cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0 ) all_input_ids = torch.nn.functional.pad( input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id @@ -270,9 +360,13 @@ class VlmCausalLMBatch(CausalLMBatch): r.all_input_ids = all_input_ids[r.idx] input_ids = input_ids.to(device) attention_mask = attention_mask.to(device) + cross_attention_mask = ( + cross_attention_mask.to(device) + if cross_attention_mask is not None + else None + ) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() return cls( @@ -287,6 +381,7 @@ class VlmCausalLMBatch(CausalLMBatch): top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, input_length=input_len, + cross_attention_mask=cross_attention_mask, ) @classmethod @@ -298,46 +393,40 @@ class VlmCausalLMBatch(CausalLMBatch): config, is_warmup, ): - # Process images first. We need all of them so that the processor - # can make the image splits the same size. And we need the final - # sizes to insert correct number of image tokens. 
+ image_inputs = {} + texts = [] images = [] - for r in requests: + batch_tokenized_inputs = {} + + for i, r in enumerate(requests): + # Each input is encoded into a list, where each element of this input list is either a string or a URL + curr_text = "" + curr_image = None for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": - pass + curr_text += chunk.text elif chunk_type == "image": image = Image.open(BytesIO(chunk.image.data)) - if config.model_type == "llava_next": - images.append(image) - else: - images.append([image]) + # TODO unsure about BOS + curr_image = image else: raise RuntimeError(f"Invalid chunk type {chunk_type}") - image_inputs = None - if images: - image_inputs = processor.image_processor(images, return_tensors="pt") - - batch_inputs = [] - max_truncation = 0 - image_id = 0 - for r in requests: - full_text = "" - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += chunk.text - elif chunk_type == "image": - full_text += image_text_replacement( - processor, image_inputs, config, image_id + if image_text_replacement(config) not in curr_text: + if "" in curr_text: + curr_text = curr_text.replace( + "", image_text_replacement(config) ) - image_id += 1 - full_text = image_text_replacement_fixup(config, full_text) + else: + curr_text = image_text_replacement(config) + curr_text - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) + texts.append(curr_text) + if curr_image is not None: + if config.model_type == "mllama": + images.append([curr_image]) + else: + images.append(curr_image) missing_inputs = 0 dummy_images = None @@ -346,45 +435,48 @@ class VlmCausalLMBatch(CausalLMBatch): missing_inputs = new_bs - len(requests) if missing_inputs > 0: dummy_inputs = [] - if len(batch_inputs) > 0: - dummy_inputs = [batch_inputs[0]] * missing_inputs + if len(texts) > 0: + dummy_inputs = [texts[0]] * missing_inputs + dummy_images = [images[0]] * missing_inputs + texts += dummy_inputs + images += dummy_images - batch_inputs += dummy_inputs - - batch_tokenized_inputs = tokenizer( - batch_inputs, + processor_output = processor( + images, + texts, truncation=True, - max_length=max_truncation, - add_special_tokens=not config.model_type == "paligemma", + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, return_tensors="pt", + padding_side="left", padding="longest", - return_token_type_ids=False, ) - - if missing_inputs > 0 and image_inputs is not None: - dummy_shape = list(image_inputs["pixel_values"].shape) - dummy_shape[0] = missing_inputs - dummy_images = torch.rand(dummy_shape) - new_image_inputs = { - "pixel_values": torch.cat( - (image_inputs["pixel_values"], dummy_images), dim=0 - ), - } - if "pixel_attention_mask" in image_inputs: - dummy_shape = list(image_inputs["pixel_attention_mask"].shape) - dummy_shape[0] = missing_inputs - dummy_attention = torch.zeros(dummy_shape) - new_image_inputs["pixel_attention_mask"] = torch.cat( - (image_inputs["pixel_attention_mask"], dummy_attention), dim=0 - ) - if "image_sizes" in image_inputs: - dummy_shape = list(list(image_inputs["image_sizes"])[0]) - dummy_shape = missing_inputs * [dummy_shape] - dummy_sizes = torch.IntTensor(dummy_shape) - new_image_inputs["image_sizes"] = torch.cat( - (image_inputs["image_sizes"], dummy_sizes), dim=0 - ) - image_inputs = new_image_inputs + if "input_ids" in processor_output: + batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]}) + 
if "attention_mask" in processor_output: + batch_tokenized_inputs.update( + {"attention_mask": processor_output["attention_mask"]} + ) + if "cross_attention_mask" in processor_output: + batch_tokenized_inputs.update( + {"cross_attention_mask": processor_output["cross_attention_mask"]} + ) + if "pixel_values" in processor_output: + image_inputs.update({"pixel_values": processor_output["pixel_values"]}) + if "pixel_attention_mask" in processor_output: + image_inputs.update( + {"pixel_attention_mask": processor_output["pixel_attention_mask"]} + ) + if "aspect_ratio_ids" in processor_output: + image_inputs.update( + {"aspect_ratio_ids": processor_output["aspect_ratio_ids"]} + ) + if "aspect_ratio_mask" in processor_output: + image_inputs.update( + {"aspect_ratio_mask": processor_output["aspect_ratio_mask"]} + ) + if "image_sizes" in processor_output: + image_inputs.update({"image_sizes": processor_output["image_sizes"]}) return batch_tokenized_inputs, image_inputs @@ -402,7 +494,9 @@ class VlmCausalLMBatch(CausalLMBatch): batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( pb.requests, tokenizer, processor, config, is_warmup ) - batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + batch = cls.from_tokenized( + pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup + ) if image_inputs is not None: batch.pixel_values = image_inputs["pixel_values"].to(device=device) if "pixel_attention_mask" in image_inputs: @@ -415,10 +509,26 @@ class VlmCausalLMBatch(CausalLMBatch): batch.image_sizes = image_inputs["image_sizes"].to(device=device) else: batch.image_sizes = None + if "aspect_ratio_ids" in image_inputs: + batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to( + device=device + ) + else: + batch.aspect_ratio_ids = None + if "aspect_ratio_mask" in image_inputs: + batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( + device=device + ) + else: + batch.aspect_ratio_mask = None else: batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.aspect_ratio_ids = None + batch.aspect_ratio_mask = None + batch.cross_attention_mask = None + return batch @classmethod @@ -440,107 +550,231 @@ class VlmCausalLMBatch(CausalLMBatch): ) -> "VlmCausalLMBatch": if not all(b.past_key_values is not None for b in batches): raise ValueError("KV cache not allocated! Cannot recombine before prefill!") + # Used for padding total_requests = sum(len(b) for b in batches) new_bs = total_requests - if is_warmup is False: + if not is_warmup: new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) - batch_id = batches[0].batch_id - device = batches[0].input_ids.device - input_lengths = [b.input_length for b in batches] - max_input_length = max(input_lengths) - offsets = [max_input_length - b.input_length for b in batches] - - cur_padding = [b.right_padding for b in batches] - # For prefill there is a space allocated only for first token - # Need to add padding to the max total tokens before first decode - - moves_needed = [ - total_requests - len(b) if b.batch_size == new_bs else total_requests - for b in batches - ] - dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = batches[dst_batch_idx].batch_size < new_bs - - # TODO: Add support for changing max seq len, i.e. 
due to output length bucketing - # FIXME: max_seq_len for non optimized code if len(batches) > 1: scenario = "CONCAT" - elif reshape: - scenario = "RESHAPE" - elif cur_padding[dst_batch_idx] <= 0: + elif batches[0].prefilling: scenario = "SHIFT" - offsets = [ - biggest_single_chunk(b.max_input_length - max_input_length) - for b in batches - ] - max_input_length = max_input_length + offsets[dst_batch_idx] else: - # Nothing to do return batches[0] dbg_trace( scenario, f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}" - f" offsets:{offsets}" - f" input_lengths:{input_lengths}" - f" cur_padding:{cur_padding}" - f" dst_batch:{dst_batch_idx}", + f" reqs:{[len(b) for b in batches]}", ) - grouped_requests = [[req for req in batch.requests] for batch in batches] - flat_requests = list(itertools.chain(*grouped_requests)) + if scenario == "SHIFT": + batch = batches[0] + batch.padding_process(pad_token_id) + return batch - for i in range(len(batches)): - target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size - batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) - batches[i].realign(target_bs, offsets[i], pad_token_id) - batches[i].split_kv_cache_if_needed(i == dst_batch_idx) - batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data( - [batches[i] for i in range(len(batches)) if i != dst_batch_idx] - ) + total_batch_size = 0 + max_input_length = 0 + for i, batch in enumerate(batches): + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.input_length) + # Batch attributes + requests = [] + input_lengths = [] + top_n_tokens = [] + parameters = [] + fsm_grammar_states = [] - top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) + # Batch tensors + input_ids = None + attention_mask = None + position_ids = None + past_key_values = [] + top_n_tokens_tensor = None + cross_attention_mask = None + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + keep_indices = [] + for req in batch.requests: + keep_indices.append(req.idx) - parameters = [r.data.parameters for r in flat_requests] - # append the dummy parameters for dummy requests - batch_size = batches[dst_batch_idx].batch_size - parameters = pad_next_token_chooser_parameters(parameters, batch_size) + requests.extend(batch.requests) + parameters.extend([r.data.parameters for r in batch.requests]) + fsm_grammar_states.extend( + [batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices] + ) + input_lengths.extend([batch.input_length]) + top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices]) - # update past grammar states - fsm_grammar_states = [0] * batch_size - for batch in batches: - for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = ( - batch.next_token_chooser.fsm_grammar_states[i] + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS)) + # # Copy to correct indices + + left_offset = max_input_length - batch.input_length + right_padding = 
MAX_TOTAL_TOKENS - max_input_length + input_ids[start_index:end_index, left_offset:-right_padding] = ( + batch.input_ids[keep_indices, : batch.input_length] + ) + + # Create padded tensor + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + new_bs, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[ + keep_indices + ] + + if attention_mask is None: + attention_mask = batch.attention_mask.new_zeros( + (new_bs, MAX_TOTAL_TOKENS), ) + attention_mask[ + start_index:end_index, + left_offset:-right_padding, + ] = batch.attention_mask[ + keep_indices, + : batch.input_length, + ] + + if batch.cross_attention_mask is not None: + cross_attention_mask_shape = list(batch.cross_attention_mask.shape) + cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS + cross_attention_mask_shape[0] = new_bs + cross_attention_mask_shape = torch.Size(cross_attention_mask_shape) + if cross_attention_mask is None: + cross_attention_mask = batch.cross_attention_mask.new_zeros( + cross_attention_mask_shape, + ) + cross_attention_mask[ + start_index:end_index, + left_offset:-right_padding, + ] = batch.cross_attention_mask[ + keep_indices, + : batch.input_length, + ] + + # Create empty tensor + # position_ids is always of shape [batch_size, 1] + if position_ids is None: + position_ids = batch.position_ids.new_empty((new_bs, 1)) + position_ids[start_index:end_index] = batch.position_ids[keep_indices, :] + + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + # And ensure that we can update tensors in-place + if isinstance(batch.past_key_values, tuple): + batch.past_key_values = [ + [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer] + for layer in batch.past_key_values + ] + elif len(batch.past_key_values[0][0].shape) == 3: + for layer in batch.past_key_values: + for k, t in enumerate(layer): + layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:]) + + start_index = end_index + + first_past_kvs = batches[0].past_key_values + _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape + past_key_values = [] + for layer_id in range(len(batches[0].past_key_values)): + if layer_id in CROSS_ATTENTION_LAYERS: + padded_past_keys_shape = list( + batches[0].past_key_values[layer_id][0].shape + ) + padded_past_keys_shape[0] = new_bs + padded_past_keys_shape = torch.Size(padded_past_keys_shape) + else: + padded_past_keys_shape = ( + new_bs, + num_heads, + MAX_TOTAL_TOKENS, + head_dim, + ) + + padded_past_keys = first_past_kvs[layer_id][0].new_zeros( + padded_past_keys_shape + ) + padded_past_values = first_past_kvs[layer_id][1].new_zeros( + padded_past_keys_shape + ) + start_index = 0 + for batch in batches: + keep_indices = [] + for req in batch.requests: + keep_indices.append(req.idx) + + left_offset = max_input_length - batch.input_length + right_padding = MAX_TOTAL_TOKENS - max_input_length + past_keys = batch.past_key_values[layer_id][0] + past_values = batch.past_key_values[layer_id][1] + # Clear reference to the original tensor + batch.past_key_values[layer_id] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + if layer_id in CROSS_ATTENTION_LAYERS: + padded_past_keys[start_index:end_index, :, :, :] = past_keys[ + keep_indices, :, :, : + ] + 
padded_past_values[start_index:end_index, :, :, :] = past_values[ + keep_indices, :, :, : + ] + + else: + padded_past_keys[ + start_index:end_index, :, left_offset:-right_padding, : + ] = past_keys[keep_indices, :, : batch.input_length, :] + padded_past_values[ + start_index:end_index, :, left_offset:-right_padding, : + ] = past_values[keep_indices, :, : batch.input_length, :] + + start_index = end_index + + past_key_values.append(tuple([padded_past_keys, padded_past_values])) + past_key_values = tuple(past_key_values) + + batch_id = batches[0].batch_id + top_n_tokens.extend([-1] * (new_bs - total_batch_size)) + fsm_grammar_states.extend([-1] * (new_bs - total_batch_size)) + + for idx, req in enumerate(requests): + req.idx = idx + + parameters = pad_next_token_chooser_parameters(parameters, new_bs) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( parameters, - batches[dst_batch_idx].next_token_chooser.dtype, - batches[dst_batch_idx].next_token_chooser.device, - batches[dst_batch_idx].next_token_chooser.tokenizer, + batches[0].next_token_chooser.dtype, + batches[0].next_token_chooser.device, + batches[0].next_token_chooser.tokenizer, fsm_grammar_states, quantization_enabled=hq_env.is_quantization_enabled, ) - - input_ids = batches[dst_batch_idx].input_ids - attention_mask = batches[dst_batch_idx].attention_mask - position_ids = batches[dst_batch_idx].position_ids - past_key_values = batches[dst_batch_idx].past_key_values input_length = max_input_length htorch.core.mark_step() return cls( batch_id=batch_id, - requests=flat_requests, + requests=requests, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -550,6 +784,13 @@ class VlmCausalLMBatch(CausalLMBatch): top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, input_length=input_length, + pixel_values=None, + pixel_attention_mask=None, + image_sizes=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=cross_attention_mask, + prefilling=False, ) @@ -601,6 +842,9 @@ class VlmCausalLM(Model): htorch.core.hpu_set_env() if world_size > 1: + os.environ.setdefault( + "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" + ) model = self.get_deepspeed_model(model_class, model_id, dtype, revision) model = hq_env.prepare_model_for_quantization(model) else: @@ -678,6 +922,11 @@ class VlmCausalLM(Model): self.kwargs["flash_attention_recompute"] = True self.speculate = get_speculate() + if model.config.model_type == "mllama": + global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS + CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers + BASE_IMAGE_TOKENS = 0 + super(VlmCausalLM, self).__init__( model_id=model_id, model=model, @@ -806,39 +1055,39 @@ class VlmCausalLM(Model): def forward( self, - input_ids, - attention_mask, - position_ids, - token_idx, - past_key_values: Optional[List[Tuple]] = None, - pixel_values: Optional[List[torch.Tensor]] = None, - image_sizes: Optional[List[Tuple[int, int]]] = None, + batch: VlmCausalLMBatch, bypass_hpu_graph: Optional[bool] = None, ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "token_idx": token_idx, - "pixel_values": pixel_values, - "image_sizes": image_sizes, + "input_ids": batch.input_ids, + "attention_mask": batch.attention_mask, + "past_key_values": batch.past_key_values, + "token_idx": batch.token_idx, + "pixel_values": batch.pixel_values, } + if 
self.model.config.model_type == "mllama": + kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids + kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask + kwargs["cross_attention_mask"] = batch.cross_attention_mask + else: + kwargs["image_sizes"] = batch.image_sizes + hpu_kwargs = {} # Optimum Habana got "lazy_mode" key-val only supported for llama type of models if self.model.config.model_type == "llama": hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 if self.has_position_ids: - kwargs["position_ids"] = position_ids - + kwargs["position_ids"] = batch.position_ids if bypass_hpu_graph is not None: hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph kwargs.update(self.kwargs) model_inputs = self.model.prepare_inputs_for_generation(**kwargs) - if past_key_values is not None: + + if batch.past_key_values is not None: return self.model.forward(**model_inputs, **hpu_kwargs) else: outputs = self.model.forward(**model_inputs, **hpu_kwargs) @@ -846,8 +1095,9 @@ class VlmCausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( - self, batches: List[VlmCausalLMBatch], is_warmup: bool = False - ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: + self, batches: list[VlmCausalLMBatch], is_warmup: bool = False + ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() # Results generations: List[Generation] = [] @@ -927,9 +1177,18 @@ class VlmCausalLM(Model): # Update attention_mask as we added a new token to input_ids batch.attention_mask.index_fill_(1, token_idx, 1) + # add cross-attn mask for new token + if batch.cross_attention_mask is not None: + cross_attention_mask_prev = batch.cross_attention_mask + if token_idx is not None: + mask = cross_attention_mask_prev[ + :, token_idx - 2 : token_idx - 1, ... 
+ ] + cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) + batch.cross_attention_mask = cross_attention_mask_prev + # Adjust lengths batch.input_length += 1 - # Update position_ids if prefill: batch.position_ids = ( @@ -955,7 +1214,7 @@ class VlmCausalLM(Model): # Check if we need to do any bookkeeping first if not prefill: - batch = batch.__class__.recombine( + batch = self.batch_type.recombine( [batch], self.tokenizer.pad_token_id, is_warmup ) @@ -977,38 +1236,34 @@ class VlmCausalLM(Model): # Execute batch if prefill: # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + # token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) batch.logits, batch.past = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - batch.pixel_values, - batch.image_sizes, + batch, bypass_hpu_graph=( prefill and self.limit_hpu_graph if self.enable_hpu_graph else None ), ) + elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): # Don't schedule next forward if max_new_tokens for all requests equals 1 # - we've already generated the first and only needed token in the prefill phase pass else: - token_idx = torch.tensor( - batch.attention_mask.shape[-1] - batch.right_padding - ).to(self.device) + # token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) batch.logits = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, + batch, bypass_hpu_graph=( prefill and self.limit_hpu_graph if self.enable_hpu_graph else None ), ) + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.aspect_ratio_ids is not None: + batch.aspect_ratio_ids = None + if batch.aspect_ratio_mask is not None: + batch.aspect_ratio_mask = None + htorch.core.mark_step() start_decode = time.time_ns() @@ -1181,7 +1436,7 @@ class VlmCausalLM(Model): return generations, batch if not stopped else None, (forward_ns, decode_ns) def batch_from_pb(self, batch, is_warmup): - return VlmCausalLMBatch.from_pb_processor( + return self.batch_type.from_pb_processor( batch, self.tokenizer, self.processor, @@ -1204,22 +1459,22 @@ class VlmCausalLM(Model): def warmup( self, request: generate_pb2.WarmupRequest ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - is_warmup = True - batch = self.batch_from_pb(request.batch, is_warmup) + global MAX_TOTAL_TOKENS + MAX_TOTAL_TOKENS = request.max_total_tokens + batch = self.batch_from_pb(request.batch, is_warmup=True) + max_input_tokens = request.max_input_tokens + max_prefill_batch_size = batch.input_ids.shape[0] try: # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch], is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) except Exception: raise RuntimeError( f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. 
" f"You need to decrease `--max-batch-prefill-tokens`" ) - global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - MAX_TOTAL_TOKENS = request.max_total_tokens - max_input_length = batch.input_ids.shape[1] - max_prefill_batch_size = batch.input_ids.shape[0] + global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST PREFILL_WARMUP_BATCH_SIZE_LIST = [] batch_size = 1 while batch_size <= max_prefill_batch_size: @@ -1228,15 +1483,19 @@ class VlmCausalLM(Model): if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size: PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) - seq_len = BASE_IMAGE_TOKENS + if self.model.config.model_type == "mllama": + seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF + else: + seq_len = BASE_IMAGE_TOKENS + PREFILL_WARMUP_SEQLEN_LIST = [] i = 0 - while seq_len <= max_input_length: + while seq_len <= max_input_tokens: PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i) i += 1 - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) + if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_tokens: + PREFILL_WARMUP_SEQLEN_LIST.append(max_input_tokens) # Prefill and decode warmup DECODE_WARMUP_BATCH_SIZE_LIST = [] @@ -1246,10 +1505,13 @@ class VlmCausalLM(Model): for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST: for seq_len in PREFILL_WARMUP_SEQLEN_LIST: batch = self.generate_warmup_batch( - request, seq_len, batch_size, is_warmup + request, seq_len, batch_size, is_warmup=True + ) + _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) + assert prefill_batch is not None + _, decode_batch, _ = self.generate_token( + [prefill_batch], is_warmup=True ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) @@ -1280,43 +1542,41 @@ class VlmCausalLM(Model): and batch_size <= max_decode_batch_size ): batches = [] - for i in range(int(batch_size / max_prefill_batch_size)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0], - DECODE_WARMUP_BATCH_SIZE_LIST[-1], - is_warmup, - ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) while batch_size <= max_decode_batch_size: - _, decode_batch, _ = self.generate_token(batches, is_warmup) + for i in range(int(batch_size / max_prefill_batch_size)): + batch = self.generate_warmup_batch( + request, + PREFILL_WARMUP_SEQLEN_LIST[0] - 1, + max_prefill_batch_size, + is_warmup=False, + ) + _, prefill_batch, _ = self.generate_token( + [batch], is_warmup=True + ) + batches.append(prefill_batch) + + _, decode_batch, _ = self.generate_token(batches, is_warmup=True) DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) batch_size = batch_size * 2 batches.clear() - for i in range(int(batch_size / max_prefill_batch_size)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0], - DECODE_WARMUP_BATCH_SIZE_LIST[-1], - is_warmup, - ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) - - batches.clear() if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2 batch_size = max_decode_batch_size for i in range(int(max_decode_batch_size / 2)): batch = self.generate_warmup_batch( - request, 
PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup + request, + PREFILL_WARMUP_SEQLEN_LIST[0] - 1, + 2, + is_warmup=False, + ) + _, prefill_batch, _ = self.generate_token( + [batch], is_warmup=True ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup) + _, decode_batch, _ = self.generate_token(batches, is_warmup=True) DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) + except Exception: raise RuntimeError( f"Not enough memory to handle batch_size({batch_size}) decode warmup." @@ -1333,7 +1593,7 @@ class VlmCausalLM(Model): ) max_supported_total_tokens = MAX_BATCH_SIZE * MAX_TOTAL_TOKENS - max_input_tokens = max_input_length + max_input_tokens = max_input_tokens max_total_tokens = MAX_TOTAL_TOKENS return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx index c7f73618..c3481f2e 100644 --- a/docs/source/backends/gaudi.mdx +++ b/docs/source/backends/gaudi.mdx @@ -291,6 +291,8 @@ The following table contains the environment variables that can be used to confi Contributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md). +**Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder. + ### Building the Docker Image from Source To build the Docker image from source:
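For reference, a minimal standalone sketch of the shape-bucketing scheme used by the warmup changes above, assuming only the PAD_SEQUENCE_TO_MULTIPLE_OF default of 128 and a max_input_tokens of at least one bucket; the helper name build_prefill_seqlen_buckets is illustrative and not part of the TGI API:

PAD_SEQUENCE_TO_MULTIPLE_OF = 128

def build_prefill_seqlen_buckets(max_input_tokens: int) -> list:
    # Buckets grow by PAD_SEQUENCE_TO_MULTIPLE_OF * 2**i, mirroring the warmup loop.
    buckets, seq_len, i = [], PAD_SEQUENCE_TO_MULTIPLE_OF, 0
    while seq_len <= max_input_tokens:
        buckets.append(seq_len)
        seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
        i += 1
    if buckets[-1] < max_input_tokens:
        buckets.append(max_input_tokens)
    return buckets

def round_up(warmup_list: list, num: int) -> int:
    # Smallest bucket >= num; fall back to num itself when the list is empty,
    # which is what the new `return i if i > 0 else num` guard provides.
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i if i > 0 else num

print(build_prefill_seqlen_buckets(2048))             # [128, 256, 512, 1024, 2048]
print(round_up(build_prefill_seqlen_buckets(2048), 300))  # 512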
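A toy illustration of the prefill padding convention in from_tokenized: the prompt is left-padded up to the selected bucket, one extra slot is reserved on the right for the first generated token, and token_idx points at that slot. The tensors and bucket size below are made up; only the padding and token_idx arithmetic mirrors the code above.

import torch
import torch.nn.functional as F

pad_token_id = 0
input_ids = torch.tensor([[101, 202, 303]])           # toy prompt, length 3
attention_mask = torch.ones_like(input_ids)

bucket_size = 7                                       # rounded-up bucket minus 1, as above
left_padding = bucket_size - input_ids.shape[1]

# (left_padding, 1): pad on the left up to the bucket, plus one right slot
# for the token produced by the prefill step.
input_ids = F.pad(input_ids, (left_padding, 1), value=pad_token_id)
attention_mask = F.pad(attention_mask, (left_padding, 1), value=0)

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# No right padding during prefill, so the write position is the last slot.
token_idx = torch.tensor(attention_mask.shape[-1] - 1)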
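The rewritten image_text_replacement / batch_tokenized_inputs path boils down to choosing one placeholder string per model type and prepending it when the request text does not already contain it. A hedged sketch follows; the idefics2, llava_next and paligemma placeholder strings are assumptions about the upstream processors (only "<|image|>" for mllama is spelled out in the patch):

def image_placeholder(model_type: str) -> str:
    # Assumed token strings for the non-mllama models.
    if model_type == "idefics2":
        return "<fake_token_around_image>" + "<image>" * 64 + "<fake_token_around_image>"
    if model_type in ("llava_next", "paligemma"):
        return "<image>"
    if model_type == "mllama":
        return "<|image|>"
    raise RuntimeError(f"Unknown config {model_type} for multimodal")

def ensure_placeholder(text: str, model_type: str) -> str:
    # Prepend the placeholder when the prompt does not carry one already.
    placeholder = image_placeholder(model_type)
    return text if placeholder in text else placeholder + text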
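The recombine and padding_process rewrite keeps every tensor at a static shape: self-attention KV caches are padded on the sequence dimension up to MAX_TOTAL_TOKENS, while mllama's cross-attention layers keep their image-defined length and are skipped. A toy sketch with invented shapes and layer ids:

import torch
import torch.nn.functional as F

MAX_TOTAL_TOKENS = 16
CROSS_ATTENTION_LAYERS = [1]                          # in the real code this comes from the model config

batch, heads, seq_len, head_dim = 2, 4, 10, 8
past_key_values = [
    (
        torch.randn(batch, heads, seq_len, head_dim),
        torch.randn(batch, heads, seq_len, head_dim),
    )
    for _ in range(3)
]

right_padding = MAX_TOTAL_TOKENS - seq_len
padded = []
for layer_id, (k, v) in enumerate(past_key_values):
    if layer_id in CROSS_ATTENTION_LAYERS:
        padded.append((k, v))                         # cross-attention length is set by the image tiles
    else:
        padded.append(
            (F.pad(k, (0, 0, 0, right_padding)), F.pad(v, (0, 0, 0, right_padding)))
        )
print([layer[0].shape for layer in padded])           # seq dim is 16 everywhere except layer 1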
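Finally, a small sketch of the cross-attention-mask bookkeeping added to generate_token for mllama, under the assumption that the mask has shape [batch, seq, num_images, num_tiles]: after a decode step, the row at token_idx - 2 is replicated into token_idx - 1, mirroring the index arithmetic in the patch so that the new position reuses the previous position's image-tile visibility.

import torch

batch, seq, num_images, num_tiles = 1, 8, 1, 4
cross_attention_mask = torch.zeros(batch, seq, num_images, num_tiles)
cross_attention_mask[:, :5] = 1                       # pretend the prompt occupied 5 slots

token_idx = 6                                         # illustrative decode-step index
mask = cross_attention_mask[:, token_idx - 2 : token_idx - 1]
cross_attention_mask.index_copy_(1, torch.tensor([token_idx - 1]), mask)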