From ac6fc70c757f1464c86e80ab293bae961e293b2d Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 21 Mar 2025 11:22:12 +0000 Subject: [PATCH] Add support for other vlm --- .../text_generation_server/models/__init__.py | 204 +++++++---- .../models/transformers_flash_vlm.py | 318 ++++++++++++------ 2 files changed, 361 insertions(+), 161 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 2b6ea31a..49281f0a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -208,6 +208,8 @@ try: ) from text_generation_server.models.transformers_flash_vlm import ( TransformersFlashVlmCausalLM, + TransformersQwen2VlmCausalLM, + TransformersGemma3VlmCausalLM, ) except ImportError as e: log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}") @@ -1161,7 +1163,6 @@ def get_model( ) elif model_type == GEMMA3: if FLASH_ATTENTION: - # TODO: Use VlmCausalLM when image support is added. return VlmCausalLM( model_id=model_id, model_class=Gemma3ForConditionalGeneration, @@ -1179,9 +1180,11 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) elif FLASH_TRANSFORMERS_BACKEND: - return TransformersFlashVlmCausalLM.fallback( + from transformers import Gemma3ForConditionalGeneration as Gemma3Model + + return TransformersGemma3VlmCausalLM.fallback( model_id, - # AutoModelForConditionalGeneration, + Gemma3Model, revision, quantize=quantize, speculator=speculator, @@ -1490,42 +1493,60 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == QWEN2_VL: - return TransformersFlashVlmCausalLM.fallback( - model_id, - # AutoModelForConditionalGeneration, - revision, - quantize=quantize, - speculator=speculator, - dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - ) - return VlmCausalLM( - model_id=model_id, - model_class=Qwen2VLForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - default_dtype=torch.bfloat16, - kv_cache_dtype=kv_cache_dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) + if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif FLASH_TRANSFORMERS_BACKEND: + from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel + + return TransformersQwen2VlmCausalLM.fallback( + model_id, + Qwen2VLModel, + revision, + quantize=quantize, + speculator=speculator, + dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + ) if model_type == QWEN2_5_VL: - return VlmCausalLM( - model_id=model_id, - model_class=Qwen2_5VLForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - default_dtype=torch.bfloat16, - kv_cache_dtype=kv_cache_dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - config_class=Qwen2_5_VLConfig, - processor_class=Qwen2_5_VLProcessor, - ) + if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2_5VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + 
trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=Qwen2_5_VLConfig, + processor_class=Qwen2_5_VLProcessor, + ) + elif FLASH_TRANSFORMERS_BACKEND: + + return TransformersQwen2VlmCausalLM.fallback( + model_id, + Qwen2VLModel, + revision, + quantize=quantize, + speculator=speculator, + dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + config_class=Qwen2_5_VLConfig, + processor_class=Qwen2_5_VLProcessor, + ) if model_type == MLLAMA: if FLASH_ATTENTION: return MllamaCausalLM( @@ -1540,6 +1561,19 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) + elif FLASH_TRANSFORMERS_BACKEND: + from transformers import MllamaForConditionalGeneration as MllamaModel + + return TransformersFlashVlmCausalLM.fallback( + model_id, + MllamaModel, + revision, + quantize=quantize, + speculator=speculator, + dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + batch_class=MllamaCausalLMBatch, + ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama")) if model_type == IDEFICS2: @@ -1558,6 +1592,19 @@ def get_model( # VRAM usage. processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) + elif FLASH_TRANSFORMERS_BACKEND: + from transformers import Idefics2ForConditionalGeneration as Idefics2Model + + return TransformersFlashVlmCausalLM.fallback( + model_id, + Idefics2Model, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == IDEFICS3: @@ -1576,37 +1623,52 @@ def get_model( # VRAM usage. processor_kwargs={"size": {"longest_edge": 1456}}, ) + elif FLASH_TRANSFORMERS_BACKEND: + from transformers import Idefics3ForConditionalGeneration as Idefics3Model + + return TransformersFlashVlmCausalLM.fallback( + model_id, + Idefics3Model, + revision, + quantize=quantize, + speculator=speculator, + dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + processor_kwargs={"size": {"longest_edge": 1456}}, + ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: - # return TransformersFlashVlmCausalLM.fallback( - # model_id, - # # AutoModelForConditionalGeneration, - # revision, - # quantize=quantize, - # speculator=speculator, - # dtype=torch.bfloat16, - # trust_remote_code=trust_remote_code, - # batch_class=PaliGemmaBatch, - # ) - # if FLASH_ATTENTION: - return VlmCausalLM( - model_id=model_id, - model_class=PaliGemmaForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - kv_cache_dtype=kv_cache_dtype, - # Works better for these models - default_dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, - ) - # else: - # raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, + ) + elif FLASH_TRANSFORMERS_BACKEND: + from transformers import PaliGemmaForConditionalGeneration as 
PaliGemmaModel + return TransformersFlashVlmCausalLM.fallback( + model_id, + PaliGemmaModel, + revision, + quantize=quantize, + speculator=speculator, + dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + batch_class=PaliGemmaBatch, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: return VlmCausalLM( @@ -1619,6 +1681,18 @@ def get_model( kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, ) + elif FLASH_TRANSFORMERS_BACKEND: + from transformers import LlavaNextForConditionalGeneration as LlavaNextModel + + return TransformersFlashVlmCausalLM.fallback( + model_id, + LlavaNextModel, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext")) diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index d4945ef0..aea2b8a8 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -17,6 +17,15 @@ import torch.nn.functional as F tracer = trace.get_tracer(__name__) +# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, +# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache +# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due +# to internal constraints it was not (yet?) possible to circumvent +REPLICATED_ATTENTION_MODELS = [ + "olmo2", + "phi3", +] + def tgi_flash_attention_forward( module, @@ -35,17 +44,13 @@ def tgi_flash_attention_forward( softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, + use_sdpa: Optional[bool] = False, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): - from loguru import logger - - logger.info("Using TGI Flash Attention") - # from pdb import set_trace; set_trace() kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) key_states = key_states.transpose(1, 2).squeeze(dim=0) value_states = value_states.transpose(1, 2).squeeze(dim=0) - # from pdb import set_trace; set_trace() # Take care of updating the cache in-place kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales) @@ -55,10 +60,8 @@ def tgi_flash_attention_forward( sliding_window = -1 if sliding_window is None else sliding_window # if module.layer_idx == 0: # from pdb import set_trace; set_trace() - if cu_seqlen_prefill is not None: - attention_mask = None - if attention_mask is None: + if not use_sdpa: attn_output = attention( query=query_states, key=key_states, @@ -72,9 +75,6 @@ def tgi_flash_attention_forward( softcap=softcap, ) else: - from loguru import logger - - logger.info("uSING FLASH ATTENTION") lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1] max_length = max(lengths) attention_mask = attention_mask[:, :, :, :max_length] @@ -135,34 +135,62 @@ def tgi_flash_attention_forward( transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward + +# Siglip transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = ( transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"] ) -# 
transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"] -# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"] -transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = ( - tgi_flash_attention_forward -) + +# Qwen2VL transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ "tgi" ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ "eager" ] +# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS +# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward -# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, -# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache -# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due -# to internal constraints it was not (yet?) possible to circumvent -REPLICATED_ATTENTION_MODELS = [ - "olmo2", - "phi3", +# Idefics2 +transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[ + "tgi" +] = transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[ + "eager" ] +transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[ + "tgi" +] = transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[ + "eager" +] + +# Idefics3 +transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[ + "tgi" +] = transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[ + "eager" +] + +# Clip +transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["tgi"] = ( + transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["sdpa"] +) + +# Mllama +transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["tgi"] = ( + transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["eager"] +) +# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS +# transformers.models.mllama.modeling_mllama.MLLAMA_TEXT_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward +# transformers.models.mllama.modeling_mllama.MLLAMA_CROSS_ATTENTION_CLASSES["tgi"] = tgi_cross_attention_forward + +# TODO: implement +# tgi_cross_attention_forward class TransformersFlashVlmCausalLM(VlmCausalLM): def __init__( self, model_id: str, + model_class, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -175,10 +203,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): kv_cache_dtype: Optional[torch.dtype] = None, batch_class=VlmCausalLMBatch, ): - # # from pdb import set_trace; set_trace() - self.batch_class = VlmCausalLMBatch + self.batch_class = batch_class self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() + self.dtype = dtype if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") @@ -204,16 +232,15 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): if processor_kwargs is None: processor_kwargs = {} - # processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}} + self.processor = processor_class.from_pretrained( model_id, 
revision=revision, trust_remote_code=trust_remote_code, **processor_kwargs, ) - from transformers import Qwen2VLForConditionalGeneration - model = Qwen2VLForConditionalGeneration.from_pretrained( + model = model_class.from_pretrained( model_id, revision=revision, torch_dtype=dtype, @@ -318,6 +345,100 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): torch.distributed.barrier(group=self.process_group) + def get_position_ids(self, input_ids, image_grid_thw, position_ids): + return position_ids + + def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): + return { + "input_ids": input_ids.unsqueeze(0), + "position_ids": position_ids.unsqueeze(0), + } + + def post_process_outputs(self, logits, lm_head_indices): + return logits.squeeze(dim=0) + + @classmethod + def fallback( + cls, + model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + batch_class: Optional[type] = VlmCausalLMBatch, + processor_kwargs: Optional[dict] = None, + ): + return cls( + model_id=model_id, + model_class=model_class, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + batch_class=batch_class, + processor_kwargs=processor_kwargs, + ) + + def _model_forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[KVCache], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + lm_head_indices: Optional[torch.Tensor], + prefill_cache_indices=None, # not used, but passed to match original signature + adapter_data=None, # not supported, but passed to match original signature + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + ): + # A value of `None` (i.e. 
no logit slicing) translates to `0` in Transformers + logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 + + inputs = self.pre_process_inputs( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + ) + + # This is equivalent to `self.model.forward`, see the monkey patch in __init__ + logits = self.model.original_forward( + input_ids=inputs["input_ids"], + position_ids=inputs["position_ids"], + past_key_values=None, # we use self.kv_cache instead of transformers cache object + use_cache=False, # we use self.kv_cache instead of transformers cache object + logits_to_keep=logits_to_keep, + return_dict=True, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + kv_head_mapping=self.kv_head_mapping, + kv_scales=self.kv_scales, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + attention_mask=inputs.get("attention_mask", None), + use_sdpa=inputs.get("use_sdpa", False), + ).logits + + logits = self.post_process_outputs(logits, lm_head_indices) + + return logits, None + + +class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM): def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor): if image_grid_thw is None: return ( @@ -391,77 +512,82 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): ) return position_ids - @classmethod - def fallback( - cls, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - batch_class: Optional[type] = VlmCausalLMBatch, - ): - return cls( - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - batch_class=batch_class, - ) + def post_process_outputs(self, logits, lm_head_indices): + return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0) - def _model_forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[KVCache], - block_tables: torch.Tensor, - slots: torch.Tensor, - seqlen: Seqlen, - max_s: int, - lm_head_indices: Optional[torch.Tensor], - prefill_cache_indices=None, # not used, but passed to match original signature - adapter_data=None, # not supported, but passed to match original signature - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, - ): - # from pdb import set_trace; set_trace() - # A value of `None` (i.e. 
no logit slicing) translates to `0` in Transformers - logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 + def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): + input_ids = input_ids.unsqueeze(0) + position_ids = position_ids.transpose(0, 1).unsqueeze(1) + return {"input_ids": input_ids, "position_ids": position_ids} - # This is equivalent to `self.model.forward`, see the monkey patch in __init__ - logits = ( - self.model.original_forward( - input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers - position_ids=position_ids.transpose(0, 1).unsqueeze( - 1 - ), # expand dim to fit Transformers - past_key_values=None, # we use self.kv_cache instead of transformers cache object - use_cache=False, # we use self.kv_cache instead of transformers cache object - logits_to_keep=logits_to_keep, - return_dict=True, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - kv_head_mapping=self.kv_head_mapping, - kv_scales=self.kv_scales, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_sizes=image_sizes, - image_grid_thw=image_grid_thw, + +class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): + def get_attention_mask(self, input_ids, cu_seqlen_prefill): + device = input_ids.device + dtype = self.dtype + min_dtype = torch.finfo(dtype).min + + lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist() + batch_size = len(lengths) + + sequence_length = max(lengths) + target_length = sequence_length + # Create the padding mask from the computed lengths. + # pad_mask: [batch, sequence_length] where True indicates valid tokens. + seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) + lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1) + pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length] + + # Build the base causal mask (for non-image tokens): + causal_mask = torch.tril( + torch.ones( + (sequence_length, sequence_length), dtype=torch.bool, device=device ) - .logits.squeeze(dim=0)[lm_head_indices] - .unsqueeze(0) + ) + base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze( + 1 + ) # [batch, sequence_length, sequence_length] + base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint + + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device ) - # from pdb import set_trace; set_trace() + image_token_mask = torch.nn.utils.rnn.pad_sequence( + torch.split(image_token_mask, lengths), batch_first=True, padding_value=0 + ) + bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze( + 1 + ) - return logits, None + # Combine the causal base mask and the bidirectional mask. 
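+ # All image tokens in a sequence become mutually visible (bidirectional) here,
+ # while text tokens keep the causal constraint already encoded in base_mask.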
+ combined_mask = torch.logical_or( + base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1) + ).to(device) + # combined_mask now has shape [batch, 1, sequence_length, sequence_length] + + full_attention_mask = torch.zeros( + (batch_size, 1, sequence_length, target_length), + device=device, + dtype=torch.bool, + ) + full_attention_mask[:, :, :, :sequence_length] = combined_mask + + final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device) + + return final_attention_mask + + def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): + inputs = { + "input_ids": input_ids.unsqueeze(0), + "position_ids": position_ids.unsqueeze(0), + } + + if cu_seqlen_prefill is not None: + attention_mask = self.get_attention_mask( + input_ids.squeeze(0), cu_seqlen_prefill + ) + inputs["attention_mask"] = attention_mask + inputs["use_sdpa"] = True + + return inputs
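
Note (illustrative only, not part of the patch): the sketch below shows how the pattern introduced here could be extended to route one more VLM through the Flash Transformers fallback, the same way get_model() now does for Qwen2-VL and Gemma3: subclass TransformersFlashVlmCausalLM and override only the hooks whose defaults do not fit the model, then construct it via fallback(). The model class, model id, and the choice of which hook to override are placeholders for the example.

import torch
from transformers import LlavaForConditionalGeneration  # placeholder model class

from text_generation_server.models.transformers_flash_vlm import (
    TransformersFlashVlmCausalLM,
)


class TransformersLlavaVlmCausalLM(TransformersFlashVlmCausalLM):
    # The base class already handles the processor, kv-cache plumbing and logits
    # slicing; per-model behaviour is isolated in get_position_ids,
    # pre_process_inputs and post_process_outputs.
    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        return {
            "input_ids": input_ids.unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }


# Mirrors the FLASH_TRANSFORMERS_BACKEND branches added to get_model() above.
model = TransformersLlavaVlmCausalLM.fallback(
    "llava-hf/llava-1.5-7b-hf",  # placeholder model id
    LlavaForConditionalGeneration,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=torch.bfloat16,
    trust_remote_code=False,
)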