Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 03:14:53 +00:00
qwen2 sliding window fix, mllama does not contain sliding window
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 800281113f
commit 99323542f0
@@ -62,7 +62,9 @@ class Qwen2Attention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
+            config.sliding_window
+            if config.use_sliding_window and config.sliding_window is not None
+            else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
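Note: Qwen2-family configs can set `sliding_window` while disabling it via `use_sliding_window`. A minimal sketch of the behavior change, using a hypothetical `SimpleNamespace` stand-in for the Hugging Face config rather than the real class:

from types import SimpleNamespace

# Hypothetical config: a window length is present, but the flag disables it.
config = SimpleNamespace(sliding_window=32768, use_sliding_window=False)

# Before this change: the flag was ignored and the window was always applied.
old_max_past = config.sliding_window if config.sliding_window is not None else -1

# After this change: the window is only applied when the config opts in.
new_max_past = (
    config.sliding_window
    if config.use_sliding_window and config.sliding_window is not None
    else -1
)

print(old_max_past, new_max_past)  # 32768 -1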
@@ -1510,6 +1510,8 @@ class FlashCausalLM(Model):
 
         if getattr(config, "sliding_window", None) is None:
             config.sliding_window = None
+        if getattr(config, "use_sliding_window", True) is False:
+            config.sliding_window = None
 
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
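For illustration, the two guards above pulled into a standalone helper (a sketch following the same logic, not code from this repo); it normalizes `config.sliding_window` to None whenever the window is missing or disabled, so later code only has to test for None:

from types import SimpleNamespace

def normalize_sliding_window(config):
    # Missing window -> None; window present but disabled via the flag -> None.
    if getattr(config, "sliding_window", None) is None:
        config.sliding_window = None
    if getattr(config, "use_sliding_window", True) is False:
        config.sliding_window = None
    return config

# Hypothetical configs for illustration.
no_window = normalize_sliding_window(SimpleNamespace())
enabled = normalize_sliding_window(SimpleNamespace(sliding_window=4096))
disabled = normalize_sliding_window(
    SimpleNamespace(sliding_window=4096, use_sliding_window=False)
)
print(no_window.sliding_window, enabled.sliding_window, disabled.sliding_window)
# None 4096 None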
@@ -1059,17 +1059,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        kwargs = {}
-        if htorch.utils.internal.is_lazy():
-            batch_size = input_lengths.shape[0]
-            seqlen = (
-                input_ids.shape[0] // batch_size
-                if batch.prefilling
-                else batch.hpu_attn_meta.block_list.shape[0]
-            )
-            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
-                batch.prefilling, seqlen, batch_size
-            )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
@@ -1082,6 +1071,26 @@ class FlashVlmCausalLM(FlashCausalLM):
         seqlen = Seqlen(
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
+        kwargs = {}
+        batch_size = input_lengths.shape[0]
+        prompt_len = (
+            input_ids.shape[0] // batch_size
+            if batch.prefilling
+            else batch.hpu_attn_meta.block_list.shape[0]
+        )
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
+            )
+        if self.sliding_window is not None:
+            attn_mask = seqlen.make_sliding_window_bias(
+                input_lengths.tolist(),
+                self.sliding_window,
+                self.dtype,
+                prompt_len,
+                batch_size,
+            )
+            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
         logits, speculative_logits = self.model.forward(
             inputs_embeds=inputs_embeds,
             position_ids=_async_h2d_tensor_copy(position_ids),
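The real mask comes from `Seqlen.make_sliding_window_bias` in this repo; as a rough, self-contained sketch of the idea only (shapes, batching, and dtype handling simplified), a causal sliding-window mask can be expressed as an additive bias where query position i may attend key positions j with i - window < j <= i:

import torch

def sliding_window_bias(seq_len: int, window: int, dtype=torch.float32) -> torch.Tensor:
    # Additive attention bias: 0.0 where attention is allowed, -inf where masked.
    # Causal + sliding window: query i attends keys j with i - window < j <= i.
    i = torch.arange(seq_len).unsqueeze(-1)  # query positions as a column
    j = torch.arange(seq_len).unsqueeze(0)   # key positions as a row
    allowed = (j <= i) & (j > i - window)
    bias = torch.zeros(seq_len, seq_len, dtype=dtype)
    return bias.masked_fill(~allowed, float("-inf"))

print(sliding_window_bias(5, window=2))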
@@ -282,43 +282,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_mapping=None,
             attn_bias=None,
         )
-        if self.sliding_window is not None:
-            block_tables_in_window = []
-            for i, bt in enumerate(block_tables):
-                block_num_in_window = (
-                    self.sliding_window + BLOCK_SIZE - 1
-                ) // BLOCK_SIZE
-                block_tables_in_window.append(
-                    bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
-                )
-            slots_in_window = []
-            start_idx = 0
-            for i, indice in enumerate(slot_indices):
-                mask = (
-                    indice - torch.arange(start_idx, indice + 1)
-                ) < self.sliding_window
-                slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
-                start_idx += blocks[i] * BLOCK_SIZE
-            slots_in_window = torch.cat(slots_in_window, dim=0)
-            (
-                block_list_in_window,
-                block_groups_in_window,
-                block_usage_in_window,
-                slots_in_window_mask,
-                _,
-            ) = generate_block_metadata(
-                self.dtype,
-                self.use_contiguous_pa,
-                slots,
-                block_tables_in_window,
-                self.bucketing_ctx,
-                slots_in_window,
-                block_bucket_size,
-            )
-            meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
-            meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
-            meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
-            meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
 
         hpu_attention_meta = trim_attn_metadata(meta)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
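Per the commit title, Mllama has no sliding window, so the branch removed above never applied to it. For reference, that code sized the in-window block table by ceiling division over the paged-KV block size; a quick worked example with hypothetical numbers (neither value is taken from this repo):

BLOCK_SIZE = 128        # hypothetical paged-attention block size
sliding_window = 1000   # hypothetical window length in tokens

# Ceiling division: smallest number of blocks that covers the window.
block_num_in_window = (sliding_window + BLOCK_SIZE - 1) // BLOCK_SIZE
print(block_num_in_window)  # 8, since 7 * 128 = 896 < 1000 <= 8 * 128 = 1024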