diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 9a20f7b5..de7641e3 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -62,7 +62,9 @@ class Qwen2Attention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
+            config.sliding_window
+            if config.use_sliding_window and config.sliding_window is not None
+            else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 550b18ec..f3f52496 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1510,6 +1510,8 @@ class FlashCausalLM(Model):
 
         if getattr(config, "sliding_window", None) is None:
             config.sliding_window = None
+        if getattr(config, "use_sliding_window", True) is False:
+            config.sliding_window = None
 
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index c179d5b1..0cd49d45 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -1059,17 +1059,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        kwargs = {}
-        if htorch.utils.internal.is_lazy():
-            batch_size = input_lengths.shape[0]
-            seqlen = (
-                input_ids.shape[0] // batch_size
-                if batch.prefilling
-                else batch.hpu_attn_meta.block_list.shape[0]
-            )
-            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
-                batch.prefilling, seqlen, batch_size
-            )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
@@ -1082,6 +1071,26 @@ class FlashVlmCausalLM(FlashCausalLM):
         seqlen = Seqlen(
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
+        kwargs = {}
+        batch_size = input_lengths.shape[0]
+        prompt_len = (
+            input_ids.shape[0] // batch_size
+            if batch.prefilling
+            else batch.hpu_attn_meta.block_list.shape[0]
+        )
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
+            )
+        if self.sliding_window is not None:
+            attn_mask = seqlen.make_sliding_window_bias(
+                input_lengths.tolist(),
+                self.sliding_window,
+                self.dtype,
+                prompt_len,
+                batch_size,
+            )
+            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
         logits, speculative_logits = self.model.forward(
             inputs_embeds=inputs_embeds,
             position_ids=_async_h2d_tensor_copy(position_ids),
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index b1f5232c..d266aad9 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -282,43 +282,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_mapping=None,
             attn_bias=None,
         )
-        if self.sliding_window is not None:
-            block_tables_in_window = []
-            for i, bt in enumerate(block_tables):
-                block_num_in_window = (
-                    self.sliding_window + BLOCK_SIZE - 1
-                ) // BLOCK_SIZE
-                block_tables_in_window.append(
-                    bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
-                )
-            slots_in_window = []
-            start_idx = 0
-            for i, indice in enumerate(slot_indices):
-                mask = (
-                    indice - torch.arange(start_idx, indice + 1)
-                ) < self.sliding_window
-                slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
-                start_idx += blocks[i] * BLOCK_SIZE
-            slots_in_window = torch.cat(slots_in_window, dim=0)
-            (
-                block_list_in_window,
-                block_groups_in_window,
-                block_usage_in_window,
-                slots_in_window_mask,
-                _,
-            ) = generate_block_metadata(
-                self.dtype,
-                self.use_contiguous_pa,
-                slots,
-                block_tables_in_window,
-                self.bucketing_ctx,
-                slots_in_window,
-                block_bucket_size,
-            )
-            meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
-            meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
-            meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
-            meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
 
         hpu_attention_meta = trim_attn_metadata(meta)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
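
Note on the FlashVlmCausalLM hunk above: when `self.sliding_window` is set, the forward pass now builds an additive attention bias on the host via `seqlen.make_sliding_window_bias(...)` and attaches it as `seqlen.attn_mask` before the async host-to-device copy. The sketch below only illustrates the general shape of such a bias, assuming a dense `[batch_size, 1, prompt_len, prompt_len]` layout with 0 inside the window and dtype-min outside; the function name `sliding_window_bias_sketch` and the padding handling are assumptions for illustration, not the Gaudi backend's actual `make_sliding_window_bias` implementation.

import torch


def sliding_window_bias_sketch(
    input_lengths: list,
    window: int,
    dtype: torch.dtype,
    prompt_len: int,
    batch_size: int,
) -> torch.Tensor:
    # Illustrative only: additive bias of shape [batch_size, 1, prompt_len, prompt_len]
    # that is 0 inside a causal sliding window and dtype-min everywhere else.
    min_val = torch.finfo(dtype).min
    bias = torch.full((batch_size, 1, prompt_len, prompt_len), min_val, dtype=dtype)
    pos = torch.arange(prompt_len)
    # Query position i may attend to key position j when 0 <= i - j < window.
    diff = pos.unsqueeze(1) - pos.unsqueeze(0)
    in_window = (diff >= 0) & (diff < window)
    for b, length in enumerate(input_lengths[:batch_size]):
        valid = in_window.clone()
        valid[:, length:] = False  # keys in the right padding are never attendable
        bias[b, 0][valid] = 0.0
    return bias


# Example: a batch of two prompts padded to prompt_len=6 with a window of 3.
mask = sliding_window_bias_sketch(
    [4, 6], window=3, dtype=torch.bfloat16, prompt_len=6, batch_size=2
)

In the diff itself, the tensor produced by the real helper is copied with `_async_h2d_tensor_copy` so that `self.model.forward(...)` receives a ready-made bias through the `Seqlen` object.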