Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-02 05:50:17 +00:00)
[gaudi] Gemma3 sliding window support (#3280)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
9f38d93051
commit
429dcd9c64
@@ -2,6 +2,7 @@ from dataclasses import dataclass
import torch
from typing import Optional, List, Dict
import collections
import torch.nn.functional as F

_TYPE_CACHE = {}

@@ -15,6 +16,12 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
slots_in_window_mask: Optional[torch.Tensor] = None
block_list_in_window: Optional[torch.Tensor] = None
block_mapping_in_window: Optional[torch.Tensor] = None
block_usage_in_window: Optional[torch.Tensor] = None
block_groups_in_window: Optional[torch.Tensor] = None
attn_bias_in_window: Optional[torch.Tensor] = None


def subtuple(

@@ -67,6 +74,12 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
"block_usage",
"block_groups",
"attn_bias",
"slots_in_window_mask",
"block_list_in_window",
"block_mapping_in_window",
"block_usage_in_window",
"block_groups_in_window",
"attn_bias_in_window",
],
)
return attention_metadata

@@ -75,6 +88,7 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
attn_mask: Optional[torch.Tensor] = None

def __init__(
self,

@@ -86,6 +100,48 @@ class Seqlen:
# Flash decoding doesn't need to clamp
return self

def make_sliding_window_bias(
self,
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
padded_input_len: Optional[int],
padded_bs: Optional[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
if seq_len != 0:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = F.pad(
mask,
(
padded_input_len - seq_len,
0,
padded_input_len - seq_len,
0,
0,
0,
),
value=0,
)
else:
mask = torch.full(
(1, padded_input_len, padded_input_len),
dtype=dtype,
fill_value=0,
)
attn_biases.append(mask)
attn_biases = torch.stack(attn_biases, dim=0)
return attn_biases.to(torch.bool)


def _async_h2d_tensor_copy(source, device="hpu"):
if source is None:

@@ -124,6 +180,7 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object:
"TrimmedSeqlen",
[
"input_lengths",
"attn_mask",
],
)
return attention_metadata
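For illustration only (not part of this diff): a minimal sketch of the bias that make_sliding_window_bias builds for one sequence, assuming seq_len=4, window_size=3 and padded_input_len=6. Rows are query positions, columns are key positions, and True means the key may be attended to.

import torch
import torch.nn.functional as F

seq_len, window_size, padded_input_len = 4, 3, 6
tensor = torch.full((1, seq_len, seq_len), dtype=torch.float32, fill_value=1)
mask = torch.tril(tensor, diagonal=0)               # causal part
mask = torch.triu(mask, diagonal=-window_size + 1)  # drop keys older than the window
mask = F.pad(
    mask,
    (padded_input_len - seq_len, 0, padded_input_len - seq_len, 0, 0, 0),
    value=0,
)
print(mask.to(torch.bool))
# Each query row i keeps only keys in [i - window_size + 1, i]; padded rows and columns stay False.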
@@ -94,13 +94,13 @@ def attention(
query,
key,
value,
attn_mask=None,
attn_mask=seqlen.attn_mask if window_size_left != -1 else None,
dropout_p=0.0,
is_causal=causal,
is_causal=causal if window_size_left == -1 else False,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths,
valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()

@@ -119,6 +119,15 @@ def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
if hpu_attention_meta.block_groups_in_window is not None:
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups_in_window, num_classes=batch_size
)
attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float())
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias_in_window=attn_bias,
block_mapping_in_window=block_mapping.to(dtype),
)
return hpu_attention_meta


@@ -132,6 +141,7 @@ def paged_attention(
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
window_size_left: int = -1,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn

@@ -139,10 +149,26 @@ def paged_attention(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
block_list=(
hpu_attention_meta.block_list
if window_size_left == -1
else hpu_attention_meta.block_list_in_window
),
block_mapping=(
hpu_attention_meta.block_mapping
if window_size_left == -1
else hpu_attention_meta.block_mapping_in_window
),
block_bias=(
hpu_attention_meta.attn_bias
if window_size_left == -1
else hpu_attention_meta.attn_bias_in_window
),
block_groups=(
hpu_attention_meta.block_groups
if window_size_left == -1
else hpu_attention_meta.block_groups_in_window
),
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
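For illustration only (not part of this diff): in set_block_mapping above, torch.log on the boolean slots_in_window_mask turns it into an additive bias, 0.0 where a cached slot lies inside the sliding window and -inf where it must be ignored. A minimal sketch:

import torch

slots_in_window_mask = torch.tensor([[True, True, False, False]])
attn_bias_in_window = torch.log(slots_in_window_mask.float())
print(attn_bias_in_window)  # tensor([[0., 0., -inf, -inf]])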
@@ -288,6 +288,7 @@ class FlashGemma2Attention(torch.nn.Module):
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.window_size,
)

return self.o_proj(

@@ -135,9 +135,6 @@ class FlashGemma3Attention(torch.nn.Module):
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
# TODO: remove this hack to support local sliding window
config = copy.deepcopy(config)
config.rope_scaling = dict(rope_type="default")
self.rotary_emb = local_rotary_emb
else:
self.window_size = -1

@@ -267,6 +264,7 @@ class FlashGemma3Attention(torch.nn.Module):
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.window_size,
)

return self.o_proj(

@@ -425,8 +423,10 @@ class FlashGemma3Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
local_config = copy.deepcopy(config)
local_config.rope_scaling = dict(rope_type="default")
local_rotary_emb = PositionRotaryEmbedding.static(
config=config,
config=local_config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,

@@ -224,6 +224,7 @@ class MistralAttention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)

return self.o_proj(

@@ -62,7 +62,9 @@ class Qwen2Attention(torch.nn.Module):
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
config.sliding_window
if config.use_sliding_window and config.sliding_window is not None
else -1
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size

@@ -150,6 +152,7 @@ class Qwen2Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

@@ -167,6 +167,7 @@ class Qwen3Attention(nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()

@@ -190,6 +190,7 @@ class Qwen3MoeAttention(nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()

@@ -280,6 +280,7 @@ class Starcoder2Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
)

return self.o_proj(
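For illustration only (not part of this diff): the model hunks above all follow the same pattern. Each attention layer resolves its window once in __init__ and hands it to the attention backend as window_size_left; -1 means no sliding window, so the backend keeps the full block metadata, while any other value selects the *_in_window variants. A small sketch of that convention, with resolve_window as a hypothetical helper name:

def resolve_window(sliding_window, use_sliding_window=True):
    # -1 disables the windowed path in attention() / paged_attention()
    return sliding_window if use_sliding_window and sliding_window is not None else -1

assert resolve_window(None) == -1         # model without a sliding window
assert resolve_window(4096, False) == -1  # window configured but disabled (Qwen2-style gating)
assert resolve_window(512) == 512         # Gemma3 local layers, Mistral, Starcoder2, ...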
@@ -81,8 +81,14 @@ from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)


def prepare_for_decode(
dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
def generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables,
bucketing_ctx,
slots_in_window=None,
block_bucket_size=None,
):
# Prepare values if we need to continue decoding
# need for HPUPagedAttentionMetadata preparation

@@ -112,11 +118,12 @@ def prepare_for_decode(
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
if block_bucket_size is None:
block_bucket_size = max(max(block_list) + 1, len(block_list))
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):

@@ -125,30 +132,38 @@ def prepare_for_decode(
block_groups = gather_list(block_groups, indices, -1)
block_usage = gather_list(block_usage, indices, 1)
else:
block_bucket_size = len(block_list)
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
if block_bucket_size is None:
block_bucket_size = len(block_list)
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
block_list = pad_list(block_list, block_bucket_size, 0)
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
slots_in_window_mask = None
if slots_in_window is not None:
slot_list = [
block_id * BLOCK_SIZE + slot_idx
for block_id in block_list
for slot_idx in range(BLOCK_SIZE)
]
slot_list = torch.tensor(slot_list, dtype=torch.int64)
slot_list = slot_list.view(-1, BLOCK_SIZE)
slots_in_window_mask = torch.isin(slot_list, slots_in_window)
for i in range(slots_in_window_mask.shape[0]):
if not slots_in_window_mask[i].any():
slots_in_window_mask[i, 0] = True

block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)

return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=None,
attn_bias=None,
)
return (
block_list,
block_groups,
block_usage,
slots_in_window_mask,
block_bucket_size,
)

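For illustration only (not part of this diff): how generate_block_metadata derives slots_in_window_mask, assuming BLOCK_SIZE = 4, a padded block_list of [2, 0] and sliding-window slots {8, 9, 10, 11}.

import torch

BLOCK_SIZE = 4
block_list = [2, 0]
slots_in_window = torch.tensor([8, 9, 10, 11])

slot_list = torch.tensor(
    [b * BLOCK_SIZE + s for b in block_list for s in range((BLOCK_SIZE))]
).view(-1, BLOCK_SIZE)                        # [[8, 9, 10, 11], [0, 1, 2, 3]]
mask = torch.isin(slot_list, slots_in_window)
for i in range(mask.shape[0]):                # keep every padded block row non-empty
    if not mask[i].any():
        mask[i, 0] = True
print(mask)  # [[True, True, True, True], [True, False, False, False]]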
@@ -962,7 +977,9 @@ class FlashCausalLMBatch(Batch):
valid_indices=None,
)

def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
def prepare_for_decode(
self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id, sliding_window
):
block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
block_tables = []
for i, bt in enumerate(self.block_tables):

@@ -975,15 +992,65 @@ class FlashCausalLMBatch(Batch):
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]

self.hpu_attn_meta = prepare_for_decode(
dtype,
use_contiguous_pa,
"hpu",
slots,
block_tables,
padded_bs,
bucketing_ctx,
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables,
bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if sliding_window is not None:
block_tables_in_window = []
for i, bt in enumerate(self.block_tables):
block_num_in_window = (
sliding_window + 2 * BLOCK_SIZE - 2 - slots[i] % BLOCK_SIZE
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, block_num[i] - block_num_in_window) : block_num[i]]
)
slots_in_window = []
for i, indice in enumerate(self.slot_indices):
start_idx = indice - self.cache_lengths[i]
mask = (
indice
- torch.arange(
start_idx,
indice + 1,
device=self.slots.device,
)
) < sliding_window
slots_in_window.append(self.slots[start_idx : indice + 1][mask])
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables_in_window,
bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)

self.hpu_attn_meta = trim_attn_metadata(meta)
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
)
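For illustration only (not part of this diff): the block-count bound used in prepare_for_decode above. With BLOCK_SIZE = 16, a sliding window of 32 and the decode slot sitting at offset slots[i] % BLOCK_SIZE inside its block, keeping the last block_num_in_window blocks of the block table is enough to cover every slot the window can reach.

BLOCK_SIZE = 16
sliding_window = 32

def block_num_in_window(slot):
    return (sliding_window + 2 * BLOCK_SIZE - 2 - slot % BLOCK_SIZE) // BLOCK_SIZE

print(block_num_in_window(16))  # 3 -> offset 0: the window can span three blocks
print(block_num_in_window(31))  # 2 -> offset 15: two blocks already cover it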
@@ -1443,6 +1510,8 @@ class FlashCausalLM(Model):

if getattr(config, "sliding_window", None) is None:
config.sliding_window = None
if getattr(config, "use_sliding_window", True) is False:
config.sliding_window = None

self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads // self.process_group.size()

@@ -1865,6 +1934,15 @@ class FlashCausalLM(Model):
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
if self.sliding_window is not None:
attn_mask = seqlen.make_sliding_window_bias(
input_lengths.tolist(),
self.sliding_window,
self.dtype,
prompt_len,
batch_size,
)
seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)

# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(

@@ -1885,17 +1963,17 @@ class FlashCausalLM(Model):
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
slot_indices = []

# fetch the last blocked to warmup block num
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)

@@ -1904,16 +1982,61 @@ class FlashCausalLM(Model):
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)

hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables,
self.bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if self.sliding_window is not None:
block_tables_in_window = []
for i, bt in enumerate(block_tables):
block_num_in_window = (
self.sliding_window + BLOCK_SIZE - 1
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
)
slots_in_window = []
start_idx = 0
for i, indice in enumerate(slot_indices):
mask = (
indice - torch.arange(start_idx, indice + 1)
) < self.sliding_window
slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
start_idx += blocks[i] * BLOCK_SIZE
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables_in_window,
self.bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)

hpu_attention_meta = trim_attn_metadata(meta)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():

@@ -2014,16 +2137,25 @@ class FlashCausalLM(Model):
)

kwargs = {}
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
if htorch.utils.internal.is_lazy():
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
if self.sliding_window is not None and batch.prefilling:
attn_mask = seqlen.make_sliding_window_bias(
input_lengths.tolist(),
self.sliding_window,
self.dtype,
prompt_len,
batch_size,
)
seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)

logits, speculative_logits = self.model.forward(
input_ids=input_ids,

@@ -2303,6 +2435,7 @@ class FlashCausalLM(Model):
self.use_contiguous_pa,
self.bucketing_ctx,
self.tokenizer.pad_token_id,
self.sliding_window,
)
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)
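For illustration only (not part of this diff): the warmup path above builds slots_in_window with a plain arange mask, assuming blocks[i] = 2, BLOCK_SIZE = 4 and sliding_window = 3. The last slot of each dummy request is its decode slot, and only the final sliding_window slots of that request are kept.

import torch

BLOCK_SIZE, sliding_window = 4, 3
start_idx, blocks_i = 0, 2
indice = (start_idx + blocks_i) * BLOCK_SIZE - 1           # last slot index = 7
mask = (indice - torch.arange(start_idx, indice + 1)) < sliding_window
print(torch.arange(start_idx, indice + 1)[mask])           # tensor([5, 6, 7])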
@@ -11,7 +11,7 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
prepare_for_decode,
generate_block_metadata,
)
from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
from loguru import logger

@@ -21,6 +21,8 @@ from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
HPUPagedAttentionMetadata,
trim_attn_metadata,
)
import habana_frameworks.torch as htorch
import time

@@ -749,33 +751,79 @@ class FlashVlmCausalLM(FlashCausalLM):
)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
slot_indices = []

# fetch the last blocked to warmup block num

for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)

seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)

hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables,
self.bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if self.sliding_window is not None:
block_tables_in_window = []
for i, bt in enumerate(block_tables):
block_num_in_window = (
self.sliding_window + BLOCK_SIZE - 1
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
)
slots_in_window = []
start_idx = 0
for i, indice in enumerate(slot_indices):
mask = (
indice - torch.arange(start_idx, indice + 1)
) < self.sliding_window
slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
start_idx += blocks[i] * BLOCK_SIZE
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables_in_window,
self.bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)

hpu_attention_meta = trim_attn_metadata(meta)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
inputs_embeds = self.get_inputs_embeds(
input_ids=input_ids.to(self.device),

@@ -1011,17 +1059,6 @@ class FlashVlmCausalLM(FlashCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)

kwargs = {}
if htorch.utils.internal.is_lazy():
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots

@@ -1034,6 +1071,26 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
kwargs = {}
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
if self.sliding_window is not None:
attn_mask = seqlen.make_sliding_window_bias(
input_lengths.tolist(),
self.sliding_window,
self.dtype,
prompt_len,
batch_size,
)
seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
logits, speculative_logits = self.model.forward(
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
@@ -12,7 +12,7 @@ from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.flash_causal_lm import (
prepare_for_decode,
generate_block_metadata,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,

@@ -23,6 +23,8 @@ from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
HPUPagedAttentionMetadata,
trim_attn_metadata,
)
import habana_frameworks.torch as htorch
from loguru import logger

@@ -224,7 +226,7 @@ def generate_cross_attention_states(
cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
):
if cross_attention_states is None:
return None, None, None
return None, None
indices_list = []
if prefilling:
for i in image_indices:

@@ -247,33 +249,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
slot_indices = []

# fetch the last blocked to warmup block num
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)

seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)

hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
self.dtype,
self.use_contiguous_pa,
slots,
block_tables,
self.bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)

hpu_attention_meta = trim_attn_metadata(meta)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
image_indices = torch.tensor(batch.image_indices)
image_indices = image_indices.repeat(batch_size)