Qwen2 sliding window fix; Mllama does not use a sliding window

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-26 22:53:24 -07:00
parent 800281113f
commit 99323542f0
4 changed files with 25 additions and 49 deletions


@@ -62,7 +62,9 @@ class Qwen2Attention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
+            config.sliding_window
+            if config.use_sliding_window and config.sliding_window is not None
+            else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
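
Note: a minimal sketch (not part of the diff) of how the new guard resolves max_past, assuming the Hugging Face Qwen2Config semantics where use_sliding_window defaults to False even when sliding_window is set; the SimpleNamespace config below is only a stand-in for illustration.

from types import SimpleNamespace

# Hypothetical config: sliding_window carries a value, but the
# use_sliding_window switch is off (the common Qwen2 default).
config = SimpleNamespace(sliding_window=4096, use_sliding_window=False)

max_past = (
    config.sliding_window
    if config.use_sliding_window and config.sliding_window is not None
    else -1
)
print(max_past)  # -1: windowed attention stays disabled despite sliding_window=4096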

@@ -1510,6 +1510,8 @@ class FlashCausalLM(Model):
         if getattr(config, "sliding_window", None) is None:
             config.sliding_window = None
+        if getattr(config, "use_sliding_window", True) is False:
+            config.sliding_window = None
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
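
Note: together, the two guards normalize the config so downstream code only has to test config.sliding_window is None. A minimal sketch of that normalization with a stand-in config object (names here are illustrative, not from the repository):

from types import SimpleNamespace

def normalize_sliding_window(config):
    # Mirror of the two guards above: a missing or disabled sliding
    # window is uniformly represented as None.
    if getattr(config, "sliding_window", None) is None:
        config.sliding_window = None
    if getattr(config, "use_sliding_window", True) is False:
        config.sliding_window = None
    return config

cfg = normalize_sliding_window(
    SimpleNamespace(sliding_window=4096, use_sliding_window=False)
)
print(cfg.sliding_window)  # None: the model is treated as full attention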

@@ -1059,17 +1059,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
-        kwargs = {}
-        if htorch.utils.internal.is_lazy():
-            batch_size = input_lengths.shape[0]
-            seqlen = (
-                input_ids.shape[0] // batch_size
-                if batch.prefilling
-                else batch.hpu_attn_meta.block_list.shape[0]
-            )
-            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
-                batch.prefilling, seqlen, batch_size
-            )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
@@ -1082,6 +1071,26 @@ class FlashVlmCausalLM(FlashCausalLM):
         seqlen = Seqlen(
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
+        kwargs = {}
+        batch_size = input_lengths.shape[0]
+        prompt_len = (
+            input_ids.shape[0] // batch_size
+            if batch.prefilling
+            else batch.hpu_attn_meta.block_list.shape[0]
+        )
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
+            )
+        if self.sliding_window is not None:
+            attn_mask = seqlen.make_sliding_window_bias(
+                input_lengths.tolist(),
+                self.sliding_window,
+                self.dtype,
+                prompt_len,
+                batch_size,
+            )
+            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
         logits, speculative_logits = self.model.forward(
             inputs_embeds=inputs_embeds,
             position_ids=_async_h2d_tensor_copy(position_ids),
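
Note: Seqlen.make_sliding_window_bias is not shown in this diff. Purely to illustrate the kind of mask being attached, the sketch below builds a causal bias limited to the last window positions for a single sequence; the helper name and shape handling are assumptions, not the actual implementation.

import torch

def sliding_window_bias(seq_len: int, window: int, dtype=torch.float32):
    # Illustrative only: query i may attend to key j when j <= i and
    # i - j < window; every other position gets -inf added to its score.
    pos = torch.arange(seq_len)
    allowed = (pos[None, :] <= pos[:, None]) & (pos[:, None] - pos[None, :] < window)
    return torch.where(
        allowed,
        torch.zeros((), dtype=dtype),
        torch.full((), float("-inf"), dtype=dtype),
    )

print(sliding_window_bias(5, 3))  # 5x5 banded causal mask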

@@ -282,43 +282,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 block_mapping=None,
                 attn_bias=None,
             )
-            if self.sliding_window is not None:
-                block_tables_in_window = []
-                for i, bt in enumerate(block_tables):
-                    block_num_in_window = (
-                        self.sliding_window + BLOCK_SIZE - 1
-                    ) // BLOCK_SIZE
-                    block_tables_in_window.append(
-                        bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
-                    )
-                slots_in_window = []
-                start_idx = 0
-                for i, indice in enumerate(slot_indices):
-                    mask = (
-                        indice - torch.arange(start_idx, indice + 1)
-                    ) < self.sliding_window
-                    slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
-                    start_idx += blocks[i] * BLOCK_SIZE
-                slots_in_window = torch.cat(slots_in_window, dim=0)
-                (
-                    block_list_in_window,
-                    block_groups_in_window,
-                    block_usage_in_window,
-                    slots_in_window_mask,
-                    _,
-                ) = generate_block_metadata(
-                    self.dtype,
-                    self.use_contiguous_pa,
-                    slots,
-                    block_tables_in_window,
-                    self.bucketing_ctx,
-                    slots_in_window,
-                    block_bucket_size,
-                )
-                meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
-                meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
-                meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
-                meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
             hpu_attention_meta = trim_attn_metadata(meta)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
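
Note: since Mllama's text config defines no sliding_window, self.sliding_window is always None here and the removed branch could never run (the commit message notes Mllama does not use a sliding window). A minimal sketch of that check with a stand-in config:

from types import SimpleNamespace

# Stand-in for Mllama's text config: no sliding_window attribute at all.
mllama_text_config = SimpleNamespace()

sliding_window = getattr(mllama_text_config, "sliding_window", None)
if sliding_window is not None:
    print("build block_list_in_window / slots_in_window metadata")
else:
    print("skip: full attention, no per-window block metadata needed")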