From 429dcd9c64a1199aa4f8bdfa45022112c5210f70 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 1 Jul 2025 16:06:01 +0800 Subject: [PATCH] [gaudi] Gemma3 sliding window support (#3280) Signed-off-by: Wang, Yi A --- .../layers/attention/common.py | 57 +++++ .../layers/attention/hpu.py | 40 ++- .../custom_modeling/flash_gemma2_modeling.py | 1 + .../custom_modeling/flash_gemma3_modeling.py | 8 +- .../custom_modeling/flash_mistral_modeling.py | 1 + .../custom_modeling/flash_qwen2_modeling.py | 5 +- .../custom_modeling/flash_qwen3_modeling.py | 1 + .../flash_qwen3_moe_modeling.py | 1 + .../flash_starcoder2_modeling.py | 1 + .../models/flash_causal_lm.py | 233 ++++++++++++++---- .../models/flash_vlm_causal_lm.py | 103 ++++++-- .../models/mllama_causal_lm.py | 36 ++- 12 files changed, 389 insertions(+), 98 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index 5e03cd44..1086c411 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import torch from typing import Optional, List, Dict import collections +import torch.nn.functional as F _TYPE_CACHE = {} @@ -15,6 +16,12 @@ class HPUPagedAttentionMetadata: block_usage: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] attn_bias: Optional[torch.Tensor] + slots_in_window_mask: Optional[torch.Tensor] = None + block_list_in_window: Optional[torch.Tensor] = None + block_mapping_in_window: Optional[torch.Tensor] = None + block_usage_in_window: Optional[torch.Tensor] = None + block_groups_in_window: Optional[torch.Tensor] = None + attn_bias_in_window: Optional[torch.Tensor] = None def subtuple( @@ -67,6 +74,12 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: "block_usage", "block_groups", "attn_bias", + "slots_in_window_mask", + "block_list_in_window", + "block_mapping_in_window", + "block_usage_in_window", + "block_groups_in_window", + "attn_bias_in_window", ], ) return attention_metadata @@ -75,6 +88,7 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: @dataclass class Seqlen: input_lengths: torch.Tensor + attn_mask: Optional[torch.Tensor] = None def __init__( self, @@ -86,6 +100,48 @@ class Seqlen: # Flash decoding doesn't need to clamp return self + def make_sliding_window_bias( + self, + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, + padded_input_len: Optional[int], + padded_bs: Optional[int], + ) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + if seq_len != 0: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = F.pad( + mask, + ( + padded_input_len - seq_len, + 0, + padded_input_len - seq_len, + 0, + 0, + 0, + ), + value=0, + ) + else: + mask = torch.full( + (1, padded_input_len, padded_input_len), + dtype=dtype, + fill_value=0, + ) + attn_biases.append(mask) + attn_biases = torch.stack(attn_biases, dim=0) + return attn_biases.to(torch.bool) + def _async_h2d_tensor_copy(source, device="hpu"): if source is None: @@ -124,6 +180,7 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object: "TrimmedSeqlen", [ "input_lengths", + "attn_mask", ], ) return 
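The new Seqlen.make_sliding_window_bias helper above builds, for each prompt, a banded boolean mask (causal attention restricted to the last window_size tokens), left-padded to the bucketed prompt length so HPU graph shapes stay static. A minimal standalone sketch of the same construction with toy sizes (not the backend's exact helper):

```python
import torch
import torch.nn.functional as F

def sliding_window_bias(seq_len: int, window: int, padded_len: int) -> torch.Tensor:
    # Causal (lower-triangular) mask, then drop keys older than `window` tokens.
    band = torch.tril(torch.ones(1, seq_len, seq_len))
    band = torch.triu(band, diagonal=-(window - 1))
    # Left-pad both the query and key dimensions up to the bucketed prompt length.
    pad = padded_len - seq_len
    band = F.pad(band, (pad, 0, pad, 0), value=0)
    return band.bool()

mask = sliding_window_bias(seq_len=5, window=3, padded_len=6)
# Within the un-padded 5x5 block, row i allows columns i-2..i only.
print(mask[0].int())
```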
attention_metadata diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index f12005d2..d3588e25 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -94,13 +94,13 @@ def attention( query, key, value, - attn_mask=None, + attn_mask=seqlen.attn_mask if window_size_left != -1 else None, dropout_p=0.0, - is_causal=causal, + is_causal=causal if window_size_left == -1 else False, scale=softmax_scale, softmax_mode="None", recompute_mode=None, - valid_sequence_lengths=seqlen.input_lengths, + valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None, padding_side="left", ) attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() @@ -119,6 +119,15 @@ def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size) hpu_attention_meta = hpu_attention_meta._replace( attn_bias=attn_bias, block_mapping=block_mapping.to(dtype) ) + if hpu_attention_meta.block_groups_in_window is not None: + block_mapping = torch.nn.functional.one_hot( + hpu_attention_meta.block_groups_in_window, num_classes=batch_size + ) + attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float()) + hpu_attention_meta = hpu_attention_meta._replace( + attn_bias_in_window=attn_bias, + block_mapping_in_window=block_mapping.to(dtype), + ) return hpu_attention_meta @@ -132,6 +141,7 @@ def paged_attention( kv_scales: KVScales, softcap: Optional[float] = None, hpu_attention_meta: HPUPagedAttentionMetadata, + window_size_left: int = -1, ): batch_size, head_num, head_size = query.shape fp8_kv = kv_cache.dtype == torch.float8_e4m3fn @@ -139,10 +149,26 @@ def paged_attention( query=query.view(batch_size, 1, head_num * head_size), key_cache=kv_cache.key, value_cache=kv_cache.value, - block_list=hpu_attention_meta.block_list, - block_mapping=hpu_attention_meta.block_mapping, - block_bias=hpu_attention_meta.attn_bias, - block_groups=hpu_attention_meta.block_groups, + block_list=( + hpu_attention_meta.block_list + if window_size_left == -1 + else hpu_attention_meta.block_list_in_window + ), + block_mapping=( + hpu_attention_meta.block_mapping + if window_size_left == -1 + else hpu_attention_meta.block_mapping_in_window + ), + block_bias=( + hpu_attention_meta.attn_bias + if window_size_left == -1 + else hpu_attention_meta.attn_bias_in_window + ), + block_groups=( + hpu_attention_meta.block_groups + if window_size_left == -1 + else hpu_attention_meta.block_groups_in_window + ), block_size=BLOCK_SIZE, scale=softmax_scale, matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 74d9397e..6ab1c4a9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -288,6 +288,7 @@ class FlashGemma2Attention(torch.nn.Module): softcap=self.softcap, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + window_size_left=self.window_size, ) return self.o_proj( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py 
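On the decode side, set_block_mapping now derives for the windowed metadata the same two auxiliary tensors it already derives for the full metadata: a one-hot block-to-sequence routing matrix from block_groups_in_window, and an additive bias obtained as the log of the boolean slots_in_window_mask (0 for cached slots inside the window, -inf for slots outside it). paged_attention then picks the *_in_window tensors whenever a layer passes window_size_left != -1; -1 stays the "no sliding window" sentinel. A toy illustration of the two transforms (not the backend's exact code):

```python
import torch

# log of a boolean keep-mask gives the additive bias the HPU paged-attention
# kernel expects: 0.0 where the cached slot is inside the window, -inf elsewhere.
slots_in_window_mask = torch.tensor([[True, True, False], [True, False, False]])
attn_bias_in_window = torch.log(slots_in_window_mask.float())
# tensor([[0., 0., -inf],
#         [0., -inf, -inf]])

# block_groups_in_window says which sequence each (padded) KV block belongs to;
# one-hot encoding turns that into the block -> batch routing matrix.
block_groups_in_window = torch.tensor([0, 0, 1, 1])
block_mapping_in_window = torch.nn.functional.one_hot(
    block_groups_in_window, num_classes=2
)
```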
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 7b789d30..c7091f90 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -135,9 +135,6 @@ class FlashGemma3Attention(torch.nn.Module): self.causal = causal if is_sliding: self.window_size = config.sliding_window - # TODO: remove this hack to support local sliding window - config = copy.deepcopy(config) - config.rope_scaling = dict(rope_type="default") self.rotary_emb = local_rotary_emb else: self.window_size = -1 @@ -267,6 +264,7 @@ class FlashGemma3Attention(torch.nn.Module): softcap=self.softcap, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + window_size_left=self.window_size, ) return self.o_proj( @@ -425,8 +423,10 @@ class FlashGemma3Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + local_config = copy.deepcopy(config) + local_config.rope_scaling = dict(rope_type="default") local_rotary_emb = PositionRotaryEmbedding.static( - config=config, + config=local_config, dim=config.head_dim, base=config.rope_local_base_freq, device=weights.device, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index f7aed118..43584d91 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -224,6 +224,7 @@ class MistralAttention(torch.nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + window_size_left=self.max_past, ) return self.o_proj( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 13e1c916..de7641e3 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -62,7 +62,9 @@ class Qwen2Attention(torch.nn.Module): ): super().__init__() self.max_past = ( - config.sliding_window if config.sliding_window is not None else -1 + config.sliding_window + if config.use_sliding_window and config.sliding_window is not None + else -1 ) self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -150,6 +152,7 @@ class Qwen2Attention(torch.nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + window_size_left=self.max_past, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index 63ee4c97..8ffbde98 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -167,6 +167,7 @@ class Qwen3Attention(nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + 
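Moving the config deep-copy out of FlashGemma3Attention and into FlashGemma3Model means the "default RoPE" copy is built once instead of per sliding layer: global layers keep the original config with its rope_scaling, while local (sliding-window) layers get plain RoPE at the local base frequency. A small sketch of the frequency split this encodes; the concrete values here (1e6 global base, 1e4 local base, head_dim 256) are assumptions taken from the public Gemma3 configs, not from this patch:

```python
import torch

def rope_inv_freq(base: float, dim: int) -> torch.Tensor:
    # Standard RoPE inverse frequencies: base ** (-2i / dim).
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Global full-attention layers: high base, plus the rope_scaling that
# PositionRotaryEmbedding.static applies when the config requests it.
global_inv_freq = rope_inv_freq(1_000_000.0, 256)

# Local sliding-window layers: rope_local_base_freq with rope_type="default",
# which is exactly what the deep-copied local_config above selects.
local_inv_freq = rope_inv_freq(10_000.0, 256)
```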
window_size_left=self.max_past, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py index da474adc..9d293eab 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py @@ -190,6 +190,7 @@ class Qwen3MoeAttention(nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + window_size_left=self.max_past, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 45baf4db..b36ead7d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -280,6 +280,7 @@ class Starcoder2Attention(torch.nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + window_size_left=self.max_past, ) return self.o_proj( diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index ca62560e..f3f52496 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -81,8 +81,14 @@ from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) -def prepare_for_decode( - dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx +def generate_block_metadata( + dtype, + use_contiguous_pa, + slots, + block_tables, + bucketing_ctx, + slots_in_window=None, + block_bucket_size=None, ): # Prepare values if we need to continue decoding # need for HPUPagedAttentionMetadata preparation @@ -112,11 +118,12 @@ def prepare_for_decode( assert len(block_list) == len(block_groups) assert len(block_list) == len(block_usage) if use_contiguous_pa: - block_bucket_size = max(max(block_list) + 1, len(block_list)) - if bucketing_ctx is not None: - block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( - block_bucket_size - ) + if block_bucket_size is None: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -125,30 +132,38 @@ def prepare_for_decode( block_groups = gather_list(block_groups, indices, -1) block_usage = gather_list(block_usage, indices, 1) else: - block_bucket_size = len(block_list) - if bucketing_ctx is not None: - block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( - block_bucket_size - ) + if block_bucket_size is None: + block_bucket_size = len(block_list) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) block_list = pad_list(block_list, block_bucket_size, 0) block_groups = pad_list(block_groups, block_bucket_size, -1) block_usage = 
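generate_block_metadata (the renamed and generalized prepare_for_decode helper) flattens the per-sequence block tables into the block_list / block_groups / block_usage triple before padding and bucketing them. A simplified reconstruction of what those base outputs look like for a toy batch; the exact expressions sit just above the hunk shown, and the padding / contiguous-PA bucketing steps are omitted here:

```python
from itertools import chain

BLOCK_SIZE = 128  # illustration only; the backend reads this from its globals

# Sequence 0 owns KV blocks [0, 1], sequence 1 owns block [2]; `slots` is the
# flat slot index each sequence is currently writing to.
block_tables = [[0, 1], [2]]
slots = [1 * BLOCK_SIZE + 70, 2 * BLOCK_SIZE + 5]

block_list = list(chain.from_iterable(block_tables))                 # [0, 1, 2]
block_groups = [i for i, bt in enumerate(block_tables) for _ in bt]  # [0, 0, 1]
# Full blocks count BLOCK_SIZE used slots; each sequence's last block counts
# only up to (and including) the slot currently being written.
block_usage = [
    BLOCK_SIZE if j < len(bt) - 1 else slots[i] % BLOCK_SIZE + 1
    for i, bt in enumerate(block_tables)
    for j in range(len(bt))
]                                                                    # [128, 71, 6]
```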
pad_list(block_usage, block_bucket_size, 1) + slots_in_window_mask = None + if slots_in_window is not None: + slot_list = [ + block_id * BLOCK_SIZE + slot_idx + for block_id in block_list + for slot_idx in range(BLOCK_SIZE) + ] + slot_list = torch.tensor(slot_list, dtype=torch.int64) + slot_list = slot_list.view(-1, BLOCK_SIZE) + slots_in_window_mask = torch.isin(slot_list, slots_in_window) + for i in range(slots_in_window_mask.shape[0]): + if not slots_in_window_mask[i].any(): + slots_in_window_mask[i, 0] = True block_list = torch.tensor(block_list, dtype=torch.int, device="cpu") block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu") block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu") - block_list_device = _async_h2d_tensor_copy(block_list) - block_groups_device = _async_h2d_tensor_copy(block_groups) - block_usage_device = _async_h2d_tensor_copy(block_usage) - - return trim_attn_metadata( - HPUPagedAttentionMetadata( - block_list=block_list_device, - block_groups=block_groups_device, - block_usage=block_usage_device, - block_mapping=None, - attn_bias=None, - ) + return ( + block_list, + block_groups, + block_usage, + slots_in_window_mask, + block_bucket_size, ) @@ -962,7 +977,9 @@ class FlashCausalLMBatch(Batch): valid_indices=None, ) - def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id): + def prepare_for_decode( + self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id, sliding_window + ): block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths] block_tables = [] for i, bt in enumerate(self.block_tables): @@ -975,15 +992,65 @@ class FlashCausalLMBatch(Batch): padded_bs = self.input_ids.shape[0] slots = self.slots[self.slot_indices] - self.hpu_attn_meta = prepare_for_decode( - dtype, - use_contiguous_pa, - "hpu", - slots, - block_tables, - padded_bs, - bucketing_ctx, + block_list, block_groups, block_usage, _, block_bucket_size = ( + generate_block_metadata( + dtype, + use_contiguous_pa, + slots, + block_tables, + bucketing_ctx, + ) ) + meta = HPUPagedAttentionMetadata( + block_list=_async_h2d_tensor_copy(block_list), + block_groups=_async_h2d_tensor_copy(block_groups), + block_usage=_async_h2d_tensor_copy(block_usage), + block_mapping=None, + attn_bias=None, + ) + if sliding_window is not None: + block_tables_in_window = [] + for i, bt in enumerate(self.block_tables): + block_num_in_window = ( + sliding_window + 2 * BLOCK_SIZE - 2 - slots[i] % BLOCK_SIZE + ) // BLOCK_SIZE + block_tables_in_window.append( + bt[max(0, block_num[i] - block_num_in_window) : block_num[i]] + ) + slots_in_window = [] + for i, indice in enumerate(self.slot_indices): + start_idx = indice - self.cache_lengths[i] + mask = ( + indice + - torch.arange( + start_idx, + indice + 1, + device=self.slots.device, + ) + ) < sliding_window + slots_in_window.append(self.slots[start_idx : indice + 1][mask]) + slots_in_window = torch.cat(slots_in_window, dim=0) + ( + block_list_in_window, + block_groups_in_window, + block_usage_in_window, + slots_in_window_mask, + _, + ) = generate_block_metadata( + dtype, + use_contiguous_pa, + slots, + block_tables_in_window, + bucketing_ctx, + slots_in_window, + block_bucket_size, + ) + meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window) + meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window) + meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window) + meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask) + + self.hpu_attn_meta = 
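The new sliding-window branch of prepare_for_decode works in two steps: it first collects, per sequence, only the slots whose distance to the slot being written is below the window, then generate_block_metadata marks, block by block, which slots of the padded window blocks belong to that set (slots_in_window_mask), which set_block_mapping later turns into the log bias shown earlier. A standalone sketch of those two steps for a single sequence, with toy sizes and without the batch bookkeeping:

```python
import torch

BLOCK_SIZE = 128
sliding_window = 256

# One sequence whose KV cache occupies slots 0..300, currently writing slot 300.
slots = torch.arange(0, 301)
last_index = 300
start_idx = 0

# Step 1: keep only the slots within `sliding_window` of the current position.
distance = last_index - torch.arange(start_idx, last_index + 1)
slots_in_window = slots[start_idx : last_index + 1][distance < sliding_window]
# slots 45..300 survive (256 of them)

# Step 2: for each block retained for the windowed lookup, flag which of its
# BLOCK_SIZE slots are in that set; rows with no hit get one slot forced on so
# the kernel never sees an all-masked bias row.
block_slots = torch.arange(0, 3 * BLOCK_SIZE).view(-1, BLOCK_SIZE)  # blocks 0..2
slots_in_window_mask = torch.isin(block_slots, slots_in_window)
for i in range(slots_in_window_mask.shape[0]):
    if not slots_in_window_mask[i].any():
        slots_in_window_mask[i, 0] = True
```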
trim_attn_metadata(meta) self.input_ids = F.pad( self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id ) @@ -1443,6 +1510,8 @@ class FlashCausalLM(Model): if getattr(config, "sliding_window", None) is None: config.sliding_window = None + if getattr(config, "use_sliding_window", True) is False: + config.sliding_window = None self.num_layers = config.num_hidden_layers self.num_heads = config.num_attention_heads // self.process_group.size() @@ -1865,6 +1934,15 @@ class FlashCausalLM(Model): kwargs["bypass_hpu_graphs"] = not self.use_graphs( True, prompt_len, batch_size ) + if self.sliding_window is not None: + attn_mask = seqlen.make_sliding_window_bias( + input_lengths.tolist(), + self.sliding_window, + self.dtype, + prompt_len, + batch_size, + ) + seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1885,17 +1963,17 @@ class FlashCausalLM(Model): position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size - past_len = [] block_tables = [] slots = [] start_idx = 0 + slot_indices = [] # fetch the last blocked to warmup block num for i in range(batch_size): block_array = list(range(start_idx, start_idx + blocks[i])) slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1) block_tables.append(block_array) - past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] input_lengths = torch.ones(batch_size, dtype=torch.int32) cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32) @@ -1904,16 +1982,61 @@ class FlashCausalLM(Model): seqlen = Seqlen( input_lengths=_async_h2d_tensor_copy(input_lengths), ) - - hpu_attention_meta = prepare_for_decode( - self.dtype, - self.use_contiguous_pa, - self.device, - slots, - block_tables, - batch_size, - bucketing_ctx=None, + block_list, block_groups, block_usage, _, block_bucket_size = ( + generate_block_metadata( + self.dtype, + self.use_contiguous_pa, + slots, + block_tables, + self.bucketing_ctx, + ) ) + meta = HPUPagedAttentionMetadata( + block_list=_async_h2d_tensor_copy(block_list), + block_groups=_async_h2d_tensor_copy(block_groups), + block_usage=_async_h2d_tensor_copy(block_usage), + block_mapping=None, + attn_bias=None, + ) + if self.sliding_window is not None: + block_tables_in_window = [] + for i, bt in enumerate(block_tables): + block_num_in_window = ( + self.sliding_window + BLOCK_SIZE - 1 + ) // BLOCK_SIZE + block_tables_in_window.append( + bt[max(0, blocks[i] - block_num_in_window) : blocks[i]] + ) + slots_in_window = [] + start_idx = 0 + for i, indice in enumerate(slot_indices): + mask = ( + indice - torch.arange(start_idx, indice + 1) + ) < self.sliding_window + slots_in_window.append(torch.arange(start_idx, indice + 1)[mask]) + start_idx += blocks[i] * BLOCK_SIZE + slots_in_window = torch.cat(slots_in_window, dim=0) + ( + block_list_in_window, + block_groups_in_window, + block_usage_in_window, + slots_in_window_mask, + _, + ) = generate_block_metadata( + self.dtype, + self.use_contiguous_pa, + slots, + block_tables_in_window, + self.bucketing_ctx, + slots_in_window, + block_bucket_size, + ) + meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window) + meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window) + meta.block_usage_in_window = 
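Two window-block bounds appear around here: the warmup path reserves a plain ceiling of sliding_window over BLOCK_SIZE blocks per sequence, while prepare_for_decode's bound additionally accounts for how far the current slot sits inside its last, partially filled block. Worked numbers, with BLOCK_SIZE = 128 and sliding_window = 512 chosen purely for illustration:

```python
BLOCK_SIZE = 128
sliding_window = 512

# Warmup bound: plain ceiling division.
warmup_blocks = (sliding_window + BLOCK_SIZE - 1) // BLOCK_SIZE           # 4

# prepare_for_decode bound: the window can straddle one extra, partially used
# block, depending on the offset of the current slot within its block.
slot = 70                                                                 # slot % BLOCK_SIZE = 70
runtime_blocks = (
    sliding_window + 2 * BLOCK_SIZE - 2 - slot % BLOCK_SIZE
) // BLOCK_SIZE                                                           # 5
```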
_async_h2d_tensor_copy(block_usage_in_window) + meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask) + + hpu_attention_meta = trim_attn_metadata(meta) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) kwargs = {} if htorch.utils.internal.is_lazy(): @@ -2014,16 +2137,25 @@ class FlashCausalLM(Model): ) kwargs = {} + batch_size = input_lengths.shape[0] + prompt_len = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] + ) if htorch.utils.internal.is_lazy(): - batch_size = input_lengths.shape[0] - prompt_len = ( - input_ids.shape[0] // batch_size - if batch.prefilling - else batch.hpu_attn_meta.block_list.shape[0] - ) kwargs["bypass_hpu_graphs"] = not self.use_graphs( batch.prefilling, prompt_len, batch_size ) + if self.sliding_window is not None and batch.prefilling: + attn_mask = seqlen.make_sliding_window_bias( + input_lengths.tolist(), + self.sliding_window, + self.dtype, + prompt_len, + batch_size, + ) + seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask) logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -2303,6 +2435,7 @@ class FlashCausalLM(Model): self.use_contiguous_pa, self.bucketing_ctx, self.tokenizer.pad_token_id, + self.sliding_window, ) if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds): self.set_inputs_embeds(batch) diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 54c35c58..0cd49d45 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -11,7 +11,7 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, - prepare_for_decode, + generate_block_metadata, ) from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE from loguru import logger @@ -21,6 +21,8 @@ from text_generation_server.layers.attention import ( Seqlen, trim_seqlen_metadata, _async_h2d_tensor_copy, + HPUPagedAttentionMetadata, + trim_attn_metadata, ) import habana_frameworks.torch as htorch import time @@ -749,33 +751,79 @@ class FlashVlmCausalLM(FlashCausalLM): ) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size - past_len = [] block_tables = [] slots = [] start_idx = 0 + slot_indices = [] # fetch the last blocked to warmup block num + for i in range(batch_size): block_array = list(range(start_idx, start_idx + blocks[i])) slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) block_tables.append(block_array) - past_len.append(blocks[i] * BLOCK_SIZE - 1) + slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1) start_idx += blocks[i] input_lengths = torch.ones(batch_size, dtype=torch.int32) seqlen = Seqlen( input_lengths=_async_h2d_tensor_copy(input_lengths), ) - - hpu_attention_meta = prepare_for_decode( - self.dtype, - self.use_contiguous_pa, - self.device, - slots, - block_tables, - batch_size, - bucketing_ctx=None, + block_list, block_groups, block_usage, _, block_bucket_size = ( + generate_block_metadata( + self.dtype, + self.use_contiguous_pa, + slots, + block_tables, + self.bucketing_ctx, + ) ) + meta = HPUPagedAttentionMetadata( + block_list=_async_h2d_tensor_copy(block_list), + block_groups=_async_h2d_tensor_copy(block_groups), + 
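With seqlen.attn_mask populated at prefill, the attention() hunk in layers/attention/hpu.py runs the fused kernel with is_causal=False and valid_sequence_lengths=None, because both causality and padding are already encoded in the precomputed banded mask. A self-contained PyTorch illustration of that contract, using torch's built-in SDPA as a stand-in for the HPU fused kernel (an analogy, not the backend's kernel):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, L, D = 1, 2, 6, 8
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

window = 3
band = torch.tril(torch.ones(L, L, dtype=torch.bool))
band = torch.triu(band, diagonal=-(window - 1))

# is_causal=False plus a banded boolean mask reproduces causal attention
# restricted to the last `window` keys; this is the same contract the
# sliding-window branch of attention() relies on above.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=band, is_causal=False)
```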
block_usage=_async_h2d_tensor_copy(block_usage), + block_mapping=None, + attn_bias=None, + ) + if self.sliding_window is not None: + block_tables_in_window = [] + for i, bt in enumerate(block_tables): + block_num_in_window = ( + self.sliding_window + BLOCK_SIZE - 1 + ) // BLOCK_SIZE + block_tables_in_window.append( + bt[max(0, blocks[i] - block_num_in_window) : blocks[i]] + ) + slots_in_window = [] + start_idx = 0 + for i, indice in enumerate(slot_indices): + mask = ( + indice - torch.arange(start_idx, indice + 1) + ) < self.sliding_window + slots_in_window.append(torch.arange(start_idx, indice + 1)[mask]) + start_idx += blocks[i] * BLOCK_SIZE + slots_in_window = torch.cat(slots_in_window, dim=0) + ( + block_list_in_window, + block_groups_in_window, + block_usage_in_window, + slots_in_window_mask, + _, + ) = generate_block_metadata( + self.dtype, + self.use_contiguous_pa, + slots, + block_tables_in_window, + self.bucketing_ctx, + slots_in_window, + block_bucket_size, + ) + meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window) + meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window) + meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window) + meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask) + + hpu_attention_meta = trim_attn_metadata(meta) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) inputs_embeds = self.get_inputs_embeds( input_ids=input_ids.to(self.device), @@ -1011,17 +1059,6 @@ class FlashVlmCausalLM(FlashCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - kwargs = {} - if htorch.utils.internal.is_lazy(): - batch_size = input_lengths.shape[0] - seqlen = ( - input_ids.shape[0] // batch_size - if batch.prefilling - else batch.hpu_attn_meta.block_list.shape[0] - ) - kwargs["bypass_hpu_graphs"] = not self.use_graphs( - batch.prefilling, seqlen, batch_size - ) if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots @@ -1034,6 +1071,26 @@ class FlashVlmCausalLM(FlashCausalLM): seqlen = Seqlen( input_lengths=_async_h2d_tensor_copy(input_lengths), ) + kwargs = {} + batch_size = input_lengths.shape[0] + prompt_len = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] + ) + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, prompt_len, batch_size + ) + if self.sliding_window is not None: + attn_mask = seqlen.make_sliding_window_bias( + input_lengths.tolist(), + self.sliding_window, + self.dtype, + prompt_len, + batch_size, + ) + seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask) logits, speculative_logits = self.model.forward( inputs_embeds=inputs_embeds, position_ids=_async_h2d_tensor_copy(position_ids), diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index dbaccfa0..d266aad9 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -12,7 +12,7 @@ from transformers import ( PreTrainedTokenizerBase, ) from text_generation_server.models.flash_causal_lm import ( - prepare_for_decode, + generate_block_metadata, ) from text_generation_server.models.flash_vlm_causal_lm import ( FlashVlmCausalLMBatch, @@ -23,6 +23,8 @@ from 
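At decode time a sliding-window layer instead goes through paged_attention, which (per the hpu.py hunk) simply swaps in the *_in_window tensors; the warmup above reuses the base call's block_bucket_size so both sets of block tensors share one padded size. A self-contained sketch of that selection, with a minimal stand-in for the metadata tuple (illustration only, not the real class):

```python
from typing import NamedTuple, Optional
import torch

class Meta(NamedTuple):
    # Minimal stand-in for HPUPagedAttentionMetadata, illustration only.
    block_list: Optional[torch.Tensor] = None
    block_mapping: Optional[torch.Tensor] = None
    attn_bias: Optional[torch.Tensor] = None
    block_groups: Optional[torch.Tensor] = None
    block_list_in_window: Optional[torch.Tensor] = None
    block_mapping_in_window: Optional[torch.Tensor] = None
    attn_bias_in_window: Optional[torch.Tensor] = None
    block_groups_in_window: Optional[torch.Tensor] = None

def select_block_metadata(meta: Meta, window_size_left: int):
    """Mirror of the selection paged_attention performs: -1 keeps the full
    block metadata, any other value switches to the *_in_window tensors."""
    if window_size_left == -1:
        return meta.block_list, meta.block_mapping, meta.attn_bias, meta.block_groups
    return (
        meta.block_list_in_window,
        meta.block_mapping_in_window,
        meta.attn_bias_in_window,
        meta.block_groups_in_window,
    )
```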
text_generation_server.layers.attention import ( Seqlen, trim_seqlen_metadata, _async_h2d_tensor_copy, + HPUPagedAttentionMetadata, + trim_attn_metadata, ) import habana_frameworks.torch as htorch from loguru import logger @@ -224,7 +226,7 @@ def generate_cross_attention_states( cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling ): if cross_attention_states is None: - return None, None, None + return None, None indices_list = [] if prefilling: for i in image_indices: @@ -247,33 +249,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size - past_len = [] block_tables = [] slots = [] start_idx = 0 + slot_indices = [] # fetch the last blocked to warmup block num for i in range(batch_size): block_array = list(range(start_idx, start_idx + blocks[i])) slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) block_tables.append(block_array) - past_len.append(blocks[i] * BLOCK_SIZE - 1) + slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1) start_idx += blocks[i] input_lengths = torch.ones(batch_size, dtype=torch.int32) seqlen = Seqlen( input_lengths=_async_h2d_tensor_copy(input_lengths), ) - - hpu_attention_meta = prepare_for_decode( - self.dtype, - self.use_contiguous_pa, - self.device, - slots, - block_tables, - batch_size, - bucketing_ctx=None, + block_list, block_groups, block_usage, _, block_bucket_size = ( + generate_block_metadata( + self.dtype, + self.use_contiguous_pa, + slots, + block_tables, + self.bucketing_ctx, + ) ) + meta = HPUPagedAttentionMetadata( + block_list=_async_h2d_tensor_copy(block_list), + block_groups=_async_h2d_tensor_copy(block_groups), + block_usage=_async_h2d_tensor_copy(block_usage), + block_mapping=None, + attn_bias=None, + ) + + hpu_attention_meta = trim_attn_metadata(meta) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. image_indices = torch.tensor(batch.image_indices) image_indices = image_indices.repeat(batch_size)
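Mllama only needed the mechanical part of this refactor: its text self-attention layers are full attention, so the warmup above builds just the base block metadata and leaves every *_in_window field at its new None default, and set_block_mapping skips its windowed branch when block_groups_in_window is None. A minimal stand-in dataclass illustrating that default-and-guard pattern (not the real HPUPagedAttentionMetadata):

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class PagedMeta:
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]
    # New windowed fields default to None so callers that only build the base
    # metadata (like the Mllama warmup above) keep working unchanged.
    slots_in_window_mask: Optional[torch.Tensor] = None
    block_groups_in_window: Optional[torch.Tensor] = None

meta = PagedMeta(
    block_list=torch.tensor([0, 1, 2]),
    block_mapping=None,
    block_usage=torch.tensor([128.0, 71.0, 6.0]),
    block_groups=torch.tensor([0, 0, 1]),
    attn_bias=None,
)
assert meta.block_groups_in_window is None  # the guard set_block_mapping checks
```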