[gaudi] Gemma3 sliding window support (#3280)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-07-01 16:06:01 +08:00 committed by GitHub
parent 9f38d93051
commit 429dcd9c64
12 changed files with 389 additions and 98 deletions

View File

@ -2,6 +2,7 @@ from dataclasses import dataclass
import torch
from typing import Optional, List, Dict
import collections
import torch.nn.functional as F
_TYPE_CACHE = {}
@ -15,6 +16,12 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
slots_in_window_mask: Optional[torch.Tensor] = None
block_list_in_window: Optional[torch.Tensor] = None
block_mapping_in_window: Optional[torch.Tensor] = None
block_usage_in_window: Optional[torch.Tensor] = None
block_groups_in_window: Optional[torch.Tensor] = None
attn_bias_in_window: Optional[torch.Tensor] = None
def subtuple(
@ -67,6 +74,12 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
"block_usage",
"block_groups",
"attn_bias",
"slots_in_window_mask",
"block_list_in_window",
"block_mapping_in_window",
"block_usage_in_window",
"block_groups_in_window",
"attn_bias_in_window",
],
)
return attention_metadata
@ -75,6 +88,7 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
attn_mask: Optional[torch.Tensor] = None
def __init__(
self,
@ -86,6 +100,48 @@ class Seqlen:
# Flash decoding doesn't need to clamp
return self
def make_sliding_window_bias(
self,
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
padded_input_len: Optional[int],
padded_bs: Optional[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
if seq_len != 0:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = F.pad(
mask,
(
padded_input_len - seq_len,
0,
padded_input_len - seq_len,
0,
0,
0,
),
value=0,
)
else:
mask = torch.full(
(1, padded_input_len, padded_input_len),
dtype=dtype,
fill_value=0,
)
attn_biases.append(mask)
attn_biases = torch.stack(attn_biases, dim=0)
return attn_biases.to(torch.bool)
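
For reference, a minimal standalone sketch of the mask this helper builds, with toy sizes (one sequence of length 5, window 3, left-padded to 6). The padding mirrors the F.pad call above and matches the padding_side="left" used by the prefill attention:

import torch
import torch.nn.functional as F

seq_len, window_size, padded_len = 5, 3, 6   # toy values

ones = torch.full((1, seq_len, seq_len), fill_value=1, dtype=torch.float32)
mask = torch.tril(ones)                                # causal lower triangle
mask = torch.triu(mask, diagonal=-(window_size - 1))   # keep only the last `window_size` keys
mask = F.pad(mask, (padded_len - seq_len, 0, padded_len - seq_len, 0, 0, 0), value=0)
print(mask.to(torch.bool)[0].int())
# values shown as 0/1 for readability:
# [[0, 0, 0, 0, 0, 0],
#  [0, 1, 0, 0, 0, 0],
#  [0, 1, 1, 0, 0, 0],
#  [0, 1, 1, 1, 0, 0],
#  [0, 0, 1, 1, 1, 0],
#  [0, 0, 0, 1, 1, 1]]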
def _async_h2d_tensor_copy(source, device="hpu"):
if source is None:
@ -124,6 +180,7 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object:
"TrimmedSeqlen",
[
"input_lengths",
"attn_mask",
],
)
return attention_metadata

View File

@ -94,13 +94,13 @@ def attention(
query,
key,
value,
attn_mask=None,
attn_mask=seqlen.attn_mask if window_size_left != -1 else None,
dropout_p=0.0,
is_causal=causal,
is_causal=causal if window_size_left == -1 else False,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths,
valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
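
For sliding-window layers (window_size_left != -1) the fused SDPA call now receives the precomputed boolean mask from Seqlen.attn_mask and turns is_causal off, since the mask already encodes both causality and the window; global layers keep the purely causal path. A rough equivalent in stock PyTorch (torch.nn.functional.scaled_dot_product_attention rather than Habana's FusedSDPA, whose extra arguments such as valid_sequence_lengths and softmax_mode are not reproduced here):

import torch
import torch.nn.functional as F

def sdpa_dispatch(query, key, value, softmax_scale, attn_mask=None, window_size_left=-1):
    # Sketch only. Shapes: (batch, heads, seq, head_dim); `scale=` needs PyTorch >= 2.1.
    if window_size_left == -1:
        # Global attention: rely on built-in causal masking.
        return F.scaled_dot_product_attention(
            query, key, value, is_causal=True, scale=softmax_scale
        )
    # Sliding-window layer: the boolean mask (True = attend) already encodes
    # causality plus the window, so causal masking is disabled.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, is_causal=False, scale=softmax_scale
    )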
@ -119,6 +119,15 @@ def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
if hpu_attention_meta.block_groups_in_window is not None:
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups_in_window, num_classes=batch_size
)
attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float())
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias_in_window=attn_bias,
block_mapping_in_window=block_mapping.to(dtype),
)
return hpu_attention_meta
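
The windowed decode bias is produced the same way as the global one: block_groups_in_window is one-hot encoded into a block-to-sequence mapping, and the boolean slots_in_window_mask becomes an additive bias through log (True -> 0.0, False -> -inf). A tiny numeric illustration with made-up values:

import torch

# Two blocks belonging to sequences 0 and 1, batch size 2 (toy values).
block_groups_in_window = torch.tensor([0, 1])
block_mapping = torch.nn.functional.one_hot(block_groups_in_window, num_classes=2)
# [[1, 0],
#  [0, 1]]

# Boolean slot mask -> additive bias: True -> 0.0, False -> -inf.
slots_in_window_mask = torch.tensor([[True, True, False],
                                     [True, False, False]])
attn_bias = torch.log(slots_in_window_mask.float())
# [[0., 0., -inf],
#  [0., -inf, -inf]]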
@ -132,6 +141,7 @@ def paged_attention(
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
window_size_left: int = -1,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
@ -139,10 +149,26 @@ def paged_attention(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
block_list=(
hpu_attention_meta.block_list
if window_size_left == -1
else hpu_attention_meta.block_list_in_window
),
block_mapping=(
hpu_attention_meta.block_mapping
if window_size_left == -1
else hpu_attention_meta.block_mapping_in_window
),
block_bias=(
hpu_attention_meta.attn_bias
if window_size_left == -1
else hpu_attention_meta.attn_bias_in_window
),
block_groups=(
hpu_attention_meta.block_groups
if window_size_left == -1
else hpu_attention_meta.block_groups_in_window
),
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
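
paged_attention now takes window_size_left and, for sliding-window layers, feeds the decode kernel the *_in_window variants of the block metadata so it only visits blocks that can still contain in-window slots. The inline conditionals above amount to the following selection (hypothetical helper, shown only to summarise the mapping):

def select_block_metadata(meta, window_size_left: int):
    """Return (block_list, block_mapping, block_bias, block_groups) for decode.

    Sketch of the selection done inline above; `meta` is an
    HPUPagedAttentionMetadata-like object."""
    if window_size_left == -1:
        return meta.block_list, meta.block_mapping, meta.attn_bias, meta.block_groups
    return (
        meta.block_list_in_window,
        meta.block_mapping_in_window,
        meta.attn_bias_in_window,
        meta.block_groups_in_window,
    )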

View File

@ -288,6 +288,7 @@ class FlashGemma2Attention(torch.nn.Module):
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.window_size,
)
return self.o_proj(

View File

@ -135,9 +135,6 @@ class FlashGemma3Attention(torch.nn.Module):
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
# TODO: remove this hack to support local sliding window
config = copy.deepcopy(config)
config.rope_scaling = dict(rope_type="default")
self.rotary_emb = local_rotary_emb
else:
self.window_size = -1
@ -267,6 +264,7 @@ class FlashGemma3Attention(torch.nn.Module):
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.window_size,
)
return self.o_proj(
@ -425,8 +423,10 @@ class FlashGemma3Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
local_config = copy.deepcopy(config)
local_config.rope_scaling = dict(rope_type="default")
local_rotary_emb = PositionRotaryEmbedding.static(
config=config,
config=local_config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,
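
Rather than deep-copying the config inside every sliding-window attention layer (the removed hack), the model now builds one local_rotary_emb from a config copy whose rope_scaling is reset to plain RoPE at rope_local_base_freq, and shares it with all local layers. The pattern in isolation, assuming a Hugging Face-style config object:

import copy

def build_local_rope_config(config):
    # Sliding-window (local) layers use default RoPE; only global layers
    # keep the scaled RoPE from the original config.
    local_config = copy.deepcopy(config)  # don't mutate the shared config
    local_config.rope_scaling = dict(rope_type="default")
    return local_config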

View File

@ -224,6 +224,7 @@ class MistralAttention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)
return self.o_proj(

View File

@ -62,7 +62,9 @@ class Qwen2Attention(torch.nn.Module):
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
config.sliding_window
if config.use_sliding_window and config.sliding_window is not None
else -1
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
@ -150,6 +152,7 @@ class Qwen2Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
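
Qwen2 configs may carry a sliding_window value while disabling it via use_sliding_window, so the effective window is only honoured when both are set. The rule as a small hypothetical helper (the in-place expression above is equivalent; -1 means "no sliding window" throughout the attention code):

def effective_window(config) -> int:
    use_it = getattr(config, "use_sliding_window", True)
    window = getattr(config, "sliding_window", None)
    return window if (use_it and window is not None) else -1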

View File

@ -167,6 +167,7 @@ class Qwen3Attention(nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()

View File

@ -190,6 +190,7 @@ class Qwen3MoeAttention(nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()

View File

@ -280,6 +280,7 @@ class Starcoder2Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)
return self.o_proj(

View File

@ -81,8 +81,14 @@ from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
def prepare_for_decode(
dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
def generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables,
bucketing_ctx,
slots_in_window=None,
block_bucket_size=None,
):
# Prepare values if we need to continue decoding
# needed for HPUPagedAttentionMetadata preparation
@ -112,6 +118,7 @@ def prepare_for_decode(
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
if block_bucket_size is None:
block_bucket_size = max(max(block_list) + 1, len(block_list))
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
@ -125,6 +132,7 @@ def prepare_for_decode(
block_groups = gather_list(block_groups, indices, -1)
block_usage = gather_list(block_usage, indices, 1)
else:
if block_bucket_size is None:
block_bucket_size = len(block_list)
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
@ -133,22 +141,29 @@ def prepare_for_decode(
block_list = pad_list(block_list, block_bucket_size, 0)
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
slots_in_window_mask = None
if slots_in_window is not None:
slot_list = [
block_id * BLOCK_SIZE + slot_idx
for block_id in block_list
for slot_idx in range(BLOCK_SIZE)
]
slot_list = torch.tensor(slot_list, dtype=torch.int64)
slot_list = slot_list.view(-1, BLOCK_SIZE)
slots_in_window_mask = torch.isin(slot_list, slots_in_window)
for i in range(slots_in_window_mask.shape[0]):
if not slots_in_window_mask[i].any():
slots_in_window_mask[i, 0] = True
block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=None,
attn_bias=None,
)
return (
block_list,
block_groups,
block_usage,
slots_in_window_mask,
block_bucket_size,
)
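
After padding, each entry of block_list is expanded into its BLOCK_SIZE slot ids and tested against the slots that are still inside the window; any block row with no in-window slot gets one artificial True so the later log bias never yields an all -inf row. A toy run of that logic (illustrative BLOCK_SIZE; the real value comes from text_generation_server.models.globals):

import torch

BLOCK_SIZE = 4                               # toy value

block_list = [2, 0, 5]                       # padded block ids
slots_in_window = torch.tensor([8, 9, 1])    # slots still inside the window

slot_list = torch.tensor(
    [b * BLOCK_SIZE + s for b in block_list for s in range(BLOCK_SIZE)],
    dtype=torch.int64,
).view(-1, BLOCK_SIZE)
# [[ 8,  9, 10, 11],
#  [ 0,  1,  2,  3],
#  [20, 21, 22, 23]]

mask = torch.isin(slot_list, slots_in_window)
# block 5 holds no in-window slot, so its first position is forced to True
for i in range(mask.shape[0]):
    if not mask[i].any():
        mask[i, 0] = True
# [[ True,  True, False, False],
#  [False,  True, False, False],
#  [ True, False, False, False]]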
@ -962,7 +977,9 @@ class FlashCausalLMBatch(Batch):
valid_indices=None,
)
def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
def prepare_for_decode(
self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id, sliding_window
):
block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
block_tables = []
for i, bt in enumerate(self.block_tables):
@ -975,15 +992,65 @@ class FlashCausalLMBatch(Batch):
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]
self.hpu_attn_meta = prepare_for_decode(
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
dtype,
use_contiguous_pa,
"hpu",
slots,
block_tables,
padded_bs,
bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if sliding_window is not None:
block_tables_in_window = []
for i, bt in enumerate(self.block_tables):
block_num_in_window = (
sliding_window + 2 * BLOCK_SIZE - 2 - slots[i] % BLOCK_SIZE
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, block_num[i] - block_num_in_window) : block_num[i]]
)
slots_in_window = []
for i, indice in enumerate(self.slot_indices):
start_idx = indice - self.cache_lengths[i]
mask = (
indice
- torch.arange(
start_idx,
indice + 1,
device=self.slots.device,
)
) < sliding_window
slots_in_window.append(self.slots[start_idx : indice + 1][mask])
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables_in_window,
bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
self.hpu_attn_meta = trim_attn_metadata(meta)
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
)
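
For every running request the sliding-window branch keeps only the slots whose distance from the current decode position is below sliding_window, plus the tail of the block table that can still contain them. A minimal per-request sketch of the slot selection (toy values; slots stands in for the request's slice of self.slots):

import torch

slots = torch.tensor([40, 41, 42, 43, 44, 45, 46])  # all slots of one request
indice = 6                                           # index of the slot being written
cache_length = 6                                     # tokens already cached
sliding_window = 3

start_idx = indice - cache_length
positions = torch.arange(start_idx, indice + 1)      # 0 .. 6
mask = (indice - positions) < sliding_window         # keep the last 3 positions
slots_in_window = slots[start_idx : indice + 1][mask]
# tensor([44, 45, 46])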
@ -1443,6 +1510,8 @@ class FlashCausalLM(Model):
if getattr(config, "sliding_window", None) is None:
config.sliding_window = None
if getattr(config, "use_sliding_window", True) is False:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads // self.process_group.size()
@ -1865,6 +1934,15 @@ class FlashCausalLM(Model):
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
if self.sliding_window is not None:
attn_mask = seqlen.make_sliding_window_bias(
input_lengths.tolist(),
self.sliding_window,
self.dtype,
prompt_len,
batch_size,
)
seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@ -1885,17 +1963,17 @@ class FlashCausalLM(Model):
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
slot_indices = []
# fetch the last block to warm up the block num
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
@ -1904,16 +1982,61 @@ class FlashCausalLM(Model):
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
self.bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if self.sliding_window is not None:
block_tables_in_window = []
for i, bt in enumerate(block_tables):
block_num_in_window = (
self.sliding_window + BLOCK_SIZE - 1
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
)
slots_in_window = []
start_idx = 0
for i, indice in enumerate(slot_indices):
mask = (
indice - torch.arange(start_idx, indice + 1)
) < self.sliding_window
slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
start_idx += blocks[i] * BLOCK_SIZE
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables_in_window,
self.bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
hpu_attention_meta = trim_attn_metadata(meta)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
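
In warmup every synthetic request sits on the last slot of its last block, so slots[i] % BLOCK_SIZE == BLOCK_SIZE - 1 and the formula used in prepare_for_decode reduces to the plain ceiling division (sliding_window + BLOCK_SIZE - 1) // BLOCK_SIZE; only that many trailing blocks of each block table are kept. A worked example with toy numbers:

BLOCK_SIZE = 128          # toy value; the real one comes from models.globals
sliding_window = 512

block_num_in_window = (sliding_window + BLOCK_SIZE - 1) // BLOCK_SIZE
# = (512 + 127) // 128 = 4 blocks can still hold in-window slots

blocks_i = 6                                  # blocks owned by request i
bt = list(range(10, 10 + blocks_i))           # its block table: [10, ..., 15]
window_tail = bt[max(0, blocks_i - block_num_in_window) : blocks_i]
# [12, 13, 14, 15] -> only the last 4 blocks are scanned for sliding-window layers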
@ -2014,16 +2137,25 @@ class FlashCausalLM(Model):
)
kwargs = {}
if htorch.utils.internal.is_lazy():
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
if self.sliding_window is not None and batch.prefilling:
attn_mask = seqlen.make_sliding_window_bias(
input_lengths.tolist(),
self.sliding_window,
self.dtype,
prompt_len,
batch_size,
)
seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
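
During a normal forward pass the dense mask is only materialised for prefill (batch.prefilling); decode steps rely on the *_in_window block metadata already attached to batch.hpu_attn_meta, so no per-token mask is needed. A sketch of the dispatch, assuming the surrounding variables:

if self.sliding_window is not None and batch.prefilling:
    # Prefill: one dense (padded_len x padded_len) boolean mask per sequence.
    attn_mask = seqlen.make_sliding_window_bias(
        input_lengths.tolist(), self.sliding_window, self.dtype, prompt_len, batch_size
    )
    seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
# Decode: nothing to build here; paged attention is restricted to the window
# by the block_list_in_window / attn_bias_in_window prepared in prepare_for_decode.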
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@ -2303,6 +2435,7 @@ class FlashCausalLM(Model):
self.use_contiguous_pa,
self.bucketing_ctx,
self.tokenizer.pad_token_id,
self.sliding_window,
)
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)

View File

@ -11,7 +11,7 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
prepare_for_decode,
generate_block_metadata,
)
from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
from loguru import logger
@ -21,6 +21,8 @@ from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
HPUPagedAttentionMetadata,
trim_attn_metadata,
)
import habana_frameworks.torch as htorch
import time
@ -749,33 +751,79 @@ class FlashVlmCausalLM(FlashCausalLM):
)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
slot_indices = []
# fetch the last block to warm up the block num
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
self.bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if self.sliding_window is not None:
block_tables_in_window = []
for i, bt in enumerate(block_tables):
block_num_in_window = (
self.sliding_window + BLOCK_SIZE - 1
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
)
slots_in_window = []
start_idx = 0
for i, indice in enumerate(slot_indices):
mask = (
indice - torch.arange(start_idx, indice + 1)
) < self.sliding_window
slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
start_idx += blocks[i] * BLOCK_SIZE
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables_in_window,
self.bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
hpu_attention_meta = trim_attn_metadata(meta)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
inputs_embeds = self.get_inputs_embeds(
input_ids=input_ids.to(self.device),
@ -1011,17 +1059,6 @@ class FlashVlmCausalLM(FlashCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
kwargs = {}
if htorch.utils.internal.is_lazy():
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots
@ -1034,6 +1071,26 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
kwargs = {}
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
if self.sliding_window is not None:
attn_mask = seqlen.make_sliding_window_bias(
input_lengths.tolist(),
self.sliding_window,
self.dtype,
prompt_len,
batch_size,
)
seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
logits, speculative_logits = self.model.forward(
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),

View File

@ -12,7 +12,7 @@ from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.flash_causal_lm import (
prepare_for_decode,
generate_block_metadata,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
@ -23,6 +23,8 @@ from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
HPUPagedAttentionMetadata,
trim_attn_metadata,
)
import habana_frameworks.torch as htorch
from loguru import logger
@ -224,7 +226,7 @@ def generate_cross_attention_states(
cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
):
if cross_attention_states is None:
return None, None, None
return None, None
indices_list = []
if prefilling:
for i in image_indices:
@ -247,33 +249,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
slot_indices = []
# fetch the last block to warm up the block num
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
self.bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
hpu_attention_meta = trim_attn_metadata(meta)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
image_indices = torch.tensor(batch.image_indices)
image_indices = image_indices.repeat(batch_size)