(fix) sliding window attention

Mohit Sharma 2025-03-13 19:30:39 +00:00
parent f91434e99b
commit ff82f0f84c
16 changed files with 54 additions and 107 deletions

View File

@@ -14,8 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio
 - [Gemma](https://huggingface.co/google/gemma-7b)
 - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
 - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
-- [Gemma3](https://huggingface.co/collections/google/gemma-3)
-- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3)
+- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
+- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
 - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)

View File

@@ -699,7 +699,7 @@ fn image_tokens(
             // TODO: prefer using the config to determine the number of features
             let num_mm_soft_tokens_per_image = 256;
             format!(
-                "\n\n<start_of_image>{:?}<end_of_image>\n\n",
+                "\n\n<start_of_image>{}<end_of_image>\n\n",
                 "<image_soft_token>".repeat(num_mm_soft_tokens_per_image)
             )
         }

View File

@@ -205,7 +205,6 @@ class LoraWeights(AdapterWeights):
         lora_a_list = [None] * nlayers
         lora_b_list = [None] * nlayers

-        # import ipdb; ipdb.set_trace()
         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
             if key not in target_to_layer:

View File

@@ -38,6 +38,7 @@ def paged_attention(
     *,
     kv_scales: KVScales,
     softcap: Optional[float] = None,
+    window_size_left: Optional[int] = -1,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -79,12 +80,15 @@
             sm_scale=softmax_scale,
             k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
             v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
+            window_size_left=window_size_left,
         )
     elif ATTENTION == "flashdecoding":
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda

+        window_size_right = -1 if window_size_left == -1 else 0
+
         # TODO fixme when flash contains the fix.
         # Number of splits is not correctly handled
         # by the current path
@@ -109,8 +113,8 @@
             softmax_scale,
             False,  # zero_tensors
             True,  # causal
-            -1,  # Window_left
-            -1,  # Window right
+            window_size_left,  # Window_left
+            window_size_right,  # Window right
             softcap,
             False,  # return softmax
             None,  # generator
@@ -253,6 +257,7 @@ def attention(
         sm_scale=softmax_scale,
         k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
         v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
+        window_size_left=window_size_left,
     )

 # If we are using flashdecoding or paged, we always use flash-attn for
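The convention assumed by the new `window_size_left` argument follows flash-attn: `-1` means an unbounded left window (plain causal attention), while a non-negative value lets each query attend only to itself and the `window_size_left` tokens before it; the right window stays closed for causal decoding, which is what the derived `window_size_right = -1 if window_size_left == -1 else 0` encodes. A minimal, self-contained sketch of that masking rule in plain PyTorch (an illustration of the semantics, not the kernel path above):

```python
import torch


def sliding_window_causal_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask: True where query i may attend to key j.

    window_size_left == -1 reproduces plain causal attention; otherwise each
    query sees itself plus at most the previous `window_size_left` tokens.
    """
    positions = torch.arange(seq_len)
    # causal constraint: a key must not lie ahead of its query
    mask = positions[None, :] <= positions[:, None]
    if window_size_left != -1:
        # drop keys that are more than `window_size_left` tokens behind the query
        mask &= positions[None, :] >= (positions[:, None] - window_size_left)
    return mask


print(sliding_window_causal_mask(5, 2).int())
# row i has ones exactly for columns max(0, i - 2) .. i
```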

View File

@@ -52,7 +52,6 @@ def use_prefill_with_paged_kv_state(
     page_size: int,
     kv_dtype: torch.dtype,
     q_dtype: torch.dtype,
-    window_left: int,
 ):
     """
     Context manager to set the active flashinfer prefill state to the given
@@ -95,7 +94,6 @@
             kv_data_type=kv_dtype,
             q_data_type=q_dtype,
             page_size=page_size,
-            window_left=-1 if window_left is None else window_left,
         )
         yield
     finally:
@@ -172,7 +170,6 @@ def use_decode_state(
     page_size: int,
     kv_cache_dtype: torch.dtype,
     q_dtype: torch.dtype,
-    window_left: int,
 ):
     """
     Context manager to set the active flashinfer decoding state to the given
@@ -209,7 +206,6 @@
             page_size=page_size,
             data_type=kv_cache_dtype,
             q_data_type=q_dtype,
-            window_left=-1 if window_left is None else window_left,
         )
         yield
     finally:

View File

@@ -272,12 +272,12 @@ class ModelType(enum.Enum):
     GEMMA3 = {
         "type": "gemma3",
         "name": "Gemma3",
-        "url": "https://huggingface.co/collections/google/gemma-3",
+        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
     }
     GEMMA3_TEXT = {
         "type": "gemma3_text",
         "name": "Gemma3 Text",
-        "url": "https://huggingface.co/collections/google/gemma-3",
+        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
     }
     COHERE = {
         "type": "cohere",

View File

@@ -287,6 +287,7 @@ class FlashGemma2Attention(torch.nn.Module):
                 max_s,
                 softcap=self.softcap,
                 kv_scales=self.kv_scales,
+                window_size_left=self.window_size,
             )

         return self.o_proj(

View File

@@ -281,22 +281,12 @@ class FlashGemma3Attention(torch.nn.Module):
             padded_query = padded_query.transpose(1, 2).contiguous()
             padded_key = padded_key.transpose(1, 2).contiguous()
             padded_value = padded_value.transpose(1, 2).contiguous()
-            zeros_to_add = torch.zeros(
-                padded_key.size(0),
-                self.num_key_value_heads,
-                1,
-                self.head_size,
-                dtype=padded_key.dtype,
-                device=padded_key.device,
-            )
-            key_states = torch.cat([padded_key, zeros_to_add], dim=2)
-            value_states = torch.cat([padded_value, zeros_to_add], dim=2)

             # Compute attention
             attn_output = F.scaled_dot_product_attention(
                 padded_query,
-                key_states,
-                value_states,
+                padded_key,
+                padded_value,
                 attn_mask=attention_mask,
                 scale=self.softmax_scale,
                 enable_gqa=self.enable_gqa,
@@ -327,6 +317,7 @@ class FlashGemma3Attention(torch.nn.Module):
                 max_s,
                 softcap=self.softcap,
                 kv_scales=self.kv_scales,
+                window_size_left=self.window_size,
             )

         return self.o_proj(
@@ -513,6 +504,7 @@ class FlashGemma3Model(torch.nn.Module):
         max_s: int,
         adapter_data: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        attention_mask_local: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -525,25 +517,6 @@ class FlashGemma3Model(torch.nn.Module):
                 position_ids, max_s, hidden_states.dtype
             )

-            # apply sliding window mask if needed
-            if layer.self_attn.window_size > 0 and attention_mask is not None:
-                min_dtype = torch.finfo(hidden_states.dtype).min
-                # prefill may be larger than sliding window
-                effective_seq_len = max(
-                    position_ids.shape[0], self.layers[i].self_attn.window_size
-                )
-                sliding_window_mask = torch.tril(
-                    torch.ones_like(attention_mask, dtype=torch.bool),
-                    diagonal=-self.layers[i].self_attn.window_size,
-                )
-                attention_mask = torch.where(
-                    sliding_window_mask, min_dtype, attention_mask
-                )
-                offset = max(0, position_ids.shape[0] - effective_seq_len)
-                attention_mask = attention_mask[
-                    :, :, offset : offset + effective_seq_len
-                ]
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -556,7 +529,11 @@
                 seqlen,
                 max_s,
                 adapter_data,
-                attention_mask,
+                (
+                    attention_mask
+                    if self.layers[i].self_attn.window_size == -1
+                    else attention_mask_local
+                ),
             )

         hidden_states, _ = self.norm(hidden_states, residual)
@@ -723,24 +700,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )

-    def get_image_token_mask(self, input_ids):
-        device = input_ids.device
-        start_token_id = self.config.boi_token_index
-        K = self.config.mm_tokens_per_image
-
-        mask = torch.zeros_like(input_ids, dtype=torch.bool, device=device)
-        start_positions = (input_ids == start_token_id).nonzero(as_tuple=True)[0]
-        mask_indices = start_positions.unsqueeze(1) + torch.arange(
-            1, K + 1, device=device
-        ).unsqueeze(0)
-
-        valid_mask = mask_indices < input_ids.size(0)
-        mask_indices = mask_indices[valid_mask]
-        mask[mask_indices] = True
-
-        return mask
-
     def get_attention_mask(
         self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
     ):
@@ -751,7 +710,7 @@
         batch_size = len(lengths)
         sequence_length = max(lengths)
-        target_length = max_s
+        target_length = sequence_length
         # Create the padding mask from the computed lengths.
         # pad_mask: [batch, sequence_length] where True indicates valid tokens.
         seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
@@ -847,7 +806,7 @@
         # # Determine the maximum sequence length (after padding) from query.
         # sequence_length = max(lengths)
-        # target_length = max_s
+        # target_length = sequence_length
         # # Create the padding mask from the computed lengths.
         # # pad_mask: [batch, sequence_length] where True indicates valid tokens.
@@ -885,6 +844,26 @@
         #     input_ids.device
         # )

+        if attention_mask is not None:
+            min_dtype = torch.finfo(inputs_embeds.dtype).min
+            # prefill may be larger than sliding window
+            effective_seq_len = max(
+                position_ids.shape[0], self.config.text_config.sliding_window
+            )
+            sliding_window_mask = torch.tril(
+                torch.ones_like(attention_mask, dtype=torch.bool),
+                diagonal=-self.config.text_config.sliding_window,
+            )
+            attention_mask_local = torch.where(
+                sliding_window_mask, min_dtype, attention_mask
+            )
+            offset = max(0, position_ids.shape[0] - effective_seq_len)
+            attention_mask_local = attention_mask_local[
+                :, :, :, offset : offset + effective_seq_len
+            ]
+        else:
+            attention_mask_local = None
+
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -895,6 +874,7 @@
             seqlen=seqlen,
             max_s=max_s,
             attention_mask=attention_mask,
+            attention_mask_local=attention_mask_local,
         )

         if lm_head_indices is not None:
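A self-contained sketch of the local-mask construction added above in `Gemma3ForConditionalGeneration.forward`: starting from the dense additive prefill mask, keys more than `sliding_window` positions behind their query are pushed to the dtype minimum and the key axis is cropped to the effective window, mirroring the `attention_mask_local` hunk. The `[batch, heads, q_len, kv_len]` shape and the toy causal mask below are assumptions for the demo, not values taken from the model:

```python
import torch


def make_local_attention_mask(
    attention_mask: torch.Tensor,  # [batch, heads, q_len, kv_len]; 0 = keep, dtype-min = masked
    sliding_window: int,
    q_len: int,
) -> torch.Tensor:
    min_dtype = torch.finfo(attention_mask.dtype).min
    # prefill may be longer than the sliding window
    effective_seq_len = max(q_len, sliding_window)
    # True for keys more than `sliding_window` positions behind their query
    sliding_window_mask = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool),
        diagonal=-sliding_window,
    )
    local = torch.where(sliding_window_mask, min_dtype, attention_mask)
    # keep only the trailing `effective_seq_len` key positions
    offset = max(0, q_len - effective_seq_len)
    return local[:, :, :, offset : offset + effective_seq_len]


batch, heads, q_len = 1, 1, 6
causal = torch.triu(
    torch.full((q_len, q_len), torch.finfo(torch.float32).min), diagonal=1
)
mask = causal[None, None]  # [1, 1, q_len, q_len] additive causal mask
print(make_local_attention_mask(mask, sliding_window=3, q_len=q_len))
```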

View File

@@ -242,6 +242,7 @@ class MistralAttention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
             )

         return self.o_proj(

View File

@@ -290,6 +290,7 @@ class MixtralAttention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -185,6 +185,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
             )

         return self.o_proj(

View File

@@ -291,6 +291,7 @@ class Starcoder2Attention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
             )

         return self.o_proj(
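Mistral, Mixtral, Qwen2 and Starcoder2 (like the Gemma layers above) now hand their own window to `paged_attention` via `window_size_left=self.max_past` instead of relying on the process-wide sliding-window state that the hunks further below delete. A hedged sketch of the mapping this assumes, with `-1` standing in for "no window"; the helper and config names here are illustrative, not the repository's exact API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TextConfig:
    # None means the model has no sliding-window limit configured
    sliding_window: Optional[int] = None


def window_size_left_from_config(config: TextConfig) -> int:
    # -1 tells the attention kernel to fall back to full causal attention
    return config.sliding_window if config.sliding_window is not None else -1


assert window_size_left_from_config(TextConfig(sliding_window=4096)) == 4096
assert window_size_left_from_config(TextConfig()) == -1
```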

View File

@@ -82,7 +82,7 @@ class Gemma3Processor(ProcessorMixin):
             do_rescale=False,
             resample=PILImageResampling.BILINEAR,
         )
-        # import ipdb; ipdb.set_trace()
         self.image_token_id = tokenizer.image_token_id
         image_tokens_expanded = "".join(
             [tokenizer.image_token] * num_mm_soft_tokens_per_image
@@ -91,8 +91,6 @@ class Gemma3Processor(ProcessorMixin):
             f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
         )

-        # import ipdb; ipdb.set_trace()
         self.image_processor = image_processor
         self.tokenizer = tokenizer
         self.chat_template = chat_template

View File

@@ -633,7 +633,7 @@ class Qwen2_5VisionModel(nn.Module):
             config=config,
             weights=weights,
         )
-        # import ipdb; ipdb.set_trace()
         self.temporal_patch_size = config.temporal_patch_size
         self.spatial_patch_size = config.spatial_patch_size
         self.in_channels = config.in_channels

View File

@@ -83,24 +83,11 @@ from text_generation_server.models.metadata_kernels import (
 tracer = trace.get_tracer(__name__)

-# Will be set in init
-SLIDING_WINDOW: Optional[int] = None
-

 def small_power_of_2(n: int):
     return 1 << ((n - 1).bit_length() - 1)


-def set_sliding_window(sliding_window: int):
-    global SLIDING_WINDOW
-    SLIDING_WINDOW = sliding_window
-
-
-def get_sliding_windows() -> int:
-    global SLIDING_WINDOW
-    return SLIDING_WINDOW
-
-
 def init_cpu_threads_env(rank_id: int, world_size: int):
     import importlib.util
@@ -1002,10 +989,8 @@ class FlashCausalLMBatch(Batch):
             self.slot_indices,
         )

-        sliding_window = get_sliding_windows()
         position_ids = []
         slot_indices = []
-        prefill_cache_indices = []

         all_prefill_logprobs = True
         no_prefill_logprobs = True
         prefill_cu_outlens = [0]
@@ -1064,14 +1049,6 @@
                 # Update
                 cumulative_slot_tokens += len(request_slots)

-            # Create tensor to slice into the kv tensor in prefill
-            if sliding_window is not None:
-                request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
-                    dtype=torch.int64,
-                )
-
             # Prefill logprobs is ignored if the request is done prefilling
             prefill_logprobs = r.prefill_logprobs and request_prefilling
@@ -1085,9 +1062,6 @@
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

-            if sliding_window is not None:
-                prefill_cache_indices.append(request_prefill_cache_indices)
-
             ADAPTER_TO_INDEX = get_adapter_to_index()
             if ADAPTER_TO_INDEX:
                 adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
@@ -1151,24 +1125,18 @@
             position_ids = torch.cat(position_ids)
             if slot_indices:
                 slot_indices = torch.cat(slot_indices)
-            if sliding_window is not None:
-                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             if position_ids:
                 position_ids = position_ids[0]
             if slot_indices:
                 slot_indices = slot_indices[0]
-            if sliding_window is not None:
-                prefill_cache_indices = prefill_cache_indices[0]

         if not has_triton():
             self.position_ids = position_ids.to(device)
             self.slot_indices = slot_indices.to(device)

         self.prefill_cu_outlens = prefill_cu_outlens
-        self.prefill_cache_indices = (
-            prefill_cache_indices.to(device) if sliding_window is not None else None
-        )
+        self.prefill_cache_indices = None

         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -1306,9 +1274,7 @@ class FlashCausalLM(Model):
         if text_config is not None:
             config = text_config

-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
+        if getattr(config, "sliding_window", None) is None:
             config.sliding_window = None

         self.num_layers = config.num_hidden_layers
@@ -2500,7 +2466,6 @@
                 page_size=BLOCK_SIZE,
                 kv_dtype=self.kv_cache_dtype,
                 q_dtype=self.dtype,
-                window_left=self.sliding_window,
             )
         else:
             assert input_lengths_tensor is not None
@@ -2514,5 +2479,4 @@
                 page_size=BLOCK_SIZE,
                 kv_cache_dtype=self.kv_cache_dtype,
                 q_dtype=self.dtype,
-                window_left=self.sliding_window,
             )
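For reference, a small sketch of what the deleted `prefill_cache_indices` path used to compute: with a sliding window set, only the trailing `sliding_window` tokens of each request's prefill were selected for the KV cache. After this commit the full prefill is cached and the window is enforced by the attention call (`window_size_left`) or the local mask instead; the function below merely re-creates the removed arithmetic for illustration:

```python
import torch


def old_prefill_cache_indices(
    cumulative_length: int, input_length: int, sliding_window: int
) -> torch.Tensor:
    # A request's tokens occupy positions
    # [cumulative_length, cumulative_length + input_length) in the flattened
    # batch; only the last `sliding_window` of them were written to the cache.
    return torch.arange(
        cumulative_length + max(0, input_length - sliding_window),
        cumulative_length + input_length,
        dtype=torch.int64,
    )


print(old_prefill_cache_indices(cumulative_length=0, input_length=10, sliding_window=4))
# tensor([6, 7, 8, 9])
```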

View File

@@ -110,7 +110,7 @@ class Model(ABC):
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
-            window_size=self.sliding_window,
+            window_size=None,  # Setting this to None disables the sliding-window block logic.
             speculate=self.speculate,
             support_chunking=self.support_chunking,
             use_prefix_caching=PREFIX_CACHING,