Mirror of https://github.com/huggingface/text-generation-inference
Commit ff82f0f84c ("(fix) sliding window attention"), parent f91434e99b
@@ -14,8 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio
 - [Gemma](https://huggingface.co/google/gemma-7b)
 - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
 - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
-- [Gemma3](https://huggingface.co/collections/google/gemma-3)
-- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3)
+- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
+- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
 - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
@@ -699,7 +699,7 @@ fn image_tokens(
            // TODO: prefer using the config to determine the number of features
            let num_mm_soft_tokens_per_image = 256;
            format!(
-                "\n\n<start_of_image>{:?}<end_of_image>\n\n",
+                "\n\n<start_of_image>{}<end_of_image>\n\n",
                "<image_soft_token>".repeat(num_mm_soft_tokens_per_image)
            )
        }
@@ -205,7 +205,6 @@ class LoraWeights(AdapterWeights):
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

-        # import ipdb; ipdb.set_trace()
        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            if key not in target_to_layer:
@@ -38,6 +38,7 @@ def paged_attention(
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
+    window_size_left: Optional[int] = -1,
 ):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
@@ -79,12 +80,15 @@ def paged_attention(
            sm_scale=softmax_scale,
            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
+            window_size_left=window_size_left,
        )
    elif ATTENTION == "flashdecoding":
        max_q = 1
        max_k = max_s
        import flash_attn_2_cuda

+        window_size_right = -1 if window_size_left == -1 else 0
+
        # TODO fixme when flash contains the fix.
        # Number of splits is not correctly handled
        # by the current path
@@ -109,8 +113,8 @@ def paged_attention(
            softmax_scale,
            False, # zero_tensors
            True, # causal
-            -1, # Window_left
-            -1, # Window right
+            window_size_left, # Window_left
+            window_size_right, # Window right
            softcap,
            False, # return softmax
            None, # generator
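
For context, the flash-attn decode path above expresses the window as a (window_size_left, window_size_right) pair, with -1 meaning unbounded on that side; causal decoding therefore uses `window_size_right = -1 if window_size_left == -1 else 0`. Below is a minimal, self-contained illustration of the masking rule this implies. The exact off-by-one convention can vary between kernels, so this sketch only shows the -1 sentinel and the left/right bounds, not the kernel implementation:

```python
import torch


def allowed_positions(seq_len: int, window_left: int, window_right: int) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask; True where query i may attend to key j.

    Key j is visible to query i when i - j <= window_left and
    j - i <= window_right; -1 removes that bound.
    """
    q = torch.arange(seq_len).unsqueeze(1)  # query positions as a column
    k = torch.arange(seq_len).unsqueeze(0)  # key positions as a row
    unbounded = torch.ones(seq_len, seq_len, dtype=torch.bool)
    left_ok = (q - k) <= window_left if window_left >= 0 else unbounded
    right_ok = (k - q) <= window_right if window_right >= 0 else unbounded
    return left_ok & right_ok


# With a left window of 3 and a causal right bound of 0, token i only sees
# tokens i-3 .. i; window_left=-1 would restore full causal attention.
print(allowed_positions(6, window_left=3, window_right=0).int())
```
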
@@ -253,6 +257,7 @@ def attention(
            sm_scale=softmax_scale,
            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
+            window_size_left=window_size_left,
        )

    # If we are using flashdecoding or paged, we always use flash-attn for
@@ -52,7 +52,6 @@ def use_prefill_with_paged_kv_state(
    page_size: int,
    kv_dtype: torch.dtype,
    q_dtype: torch.dtype,
-    window_left: int,
 ):
    """
    Context manager to set the active flashinfer prefill state to the given
@@ -95,7 +94,6 @@ def use_prefill_with_paged_kv_state(
            kv_data_type=kv_dtype,
            q_data_type=q_dtype,
            page_size=page_size,
-            window_left=-1 if window_left is None else window_left,
        )
        yield
    finally:
@@ -172,7 +170,6 @@ def use_decode_state(
    page_size: int,
    kv_cache_dtype: torch.dtype,
    q_dtype: torch.dtype,
-    window_left: int,
 ):
    """
    Context manager to set the active flashinfer decoding state to the given
@@ -209,7 +206,6 @@ def use_decode_state(
            page_size=page_size,
            data_type=kv_cache_dtype,
            q_data_type=q_dtype,
-            window_left=-1 if window_left is None else window_left,
        )
        yield
    finally:
@@ -272,12 +272,12 @@ class ModelType(enum.Enum):
    GEMMA3 = {
        "type": "gemma3",
        "name": "Gemma3",
-        "url": "https://huggingface.co/collections/google/gemma-3",
+        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
    }
    GEMMA3_TEXT = {
        "type": "gemma3_text",
        "name": "Gemma3 Text",
-        "url": "https://huggingface.co/collections/google/gemma-3",
+        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
    }
    COHERE = {
        "type": "cohere",
@@ -287,6 +287,7 @@ class FlashGemma2Attention(torch.nn.Module):
                max_s,
                softcap=self.softcap,
                kv_scales=self.kv_scales,
+                window_size_left=self.window_size,
            )

        return self.o_proj(
@@ -281,22 +281,12 @@ class FlashGemma3Attention(torch.nn.Module):
            padded_query = padded_query.transpose(1, 2).contiguous()
            padded_key = padded_key.transpose(1, 2).contiguous()
            padded_value = padded_value.transpose(1, 2).contiguous()
-            zeros_to_add = torch.zeros(
-                padded_key.size(0),
-                self.num_key_value_heads,
-                1,
-                self.head_size,
-                dtype=padded_key.dtype,
-                device=padded_key.device,
-            )
-            key_states = torch.cat([padded_key, zeros_to_add], dim=2)
-            value_states = torch.cat([padded_value, zeros_to_add], dim=2)

            # Compute attention
            attn_output = F.scaled_dot_product_attention(
                padded_query,
-                key_states,
-                value_states,
+                padded_key,
+                padded_value,
                attn_mask=attention_mask,
                scale=self.softmax_scale,
                enable_gqa=self.enable_gqa,
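
The prefill fallback above now hands `padded_key`/`padded_value` straight to PyTorch SDPA and relies on `enable_gqa` to broadcast the smaller number of KV heads across the query heads, instead of concatenating a zero slot onto K and V. A small standalone sketch of that pattern; the shapes are made up for illustration and `enable_gqa` requires a recent PyTorch release:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: 8 query heads grouped over 2 KV heads.
batch, n_heads, n_kv_heads, seq_len, head_dim = 1, 8, 2, 16, 64
query = torch.randn(batch, n_heads, seq_len, head_dim)
key = torch.randn(batch, n_kv_heads, seq_len, head_dim)
value = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# enable_gqa=True lets SDPA repeat the KV heads internally, so no manual
# torch.cat / expand of key and value is needed before the call.
out = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```
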
@@ -327,6 +317,7 @@ class FlashGemma3Attention(torch.nn.Module):
                max_s,
                softcap=self.softcap,
                kv_scales=self.kv_scales,
+                window_size_left=self.window_size,
            )

        return self.o_proj(
@@ -513,6 +504,7 @@ class FlashGemma3Model(torch.nn.Module):
        max_s: int,
        adapter_data: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask_local: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

@@ -525,25 +517,6 @@ class FlashGemma3Model(torch.nn.Module):
                position_ids, max_s, hidden_states.dtype
            )

-            # apply sliding window mask if needed
-            if layer.self_attn.window_size > 0 and attention_mask is not None:
-                min_dtype = torch.finfo(hidden_states.dtype).min
-                # prefill may be larger than sliding window
-                effective_seq_len = max(
-                    position_ids.shape[0], self.layers[i].self_attn.window_size
-                )
-                sliding_window_mask = torch.tril(
-                    torch.ones_like(attention_mask, dtype=torch.bool),
-                    diagonal=-self.layers[i].self_attn.window_size,
-                )
-                attention_mask = torch.where(
-                    sliding_window_mask, min_dtype, attention_mask
-                )
-                offset = max(0, position_ids.shape[0] - effective_seq_len)
-                attention_mask = attention_mask[
-                    :, :, offset : offset + effective_seq_len
-                ]
-
            hidden_states, residual = layer(
                hidden_states,
                residual,
@@ -556,7 +529,11 @@ class FlashGemma3Model(torch.nn.Module):
                seqlen,
                max_s,
                adapter_data,
-                attention_mask,
+                (
+                    attention_mask
+                    if self.layers[i].self_attn.window_size == -1
+                    else attention_mask_local
+                ),
            )

        hidden_states, _ = self.norm(hidden_states, residual)
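
Instead of mutating `attention_mask` inside the decoder loop, each layer now simply picks between the global mask and the precomputed local mask. A condensed sketch of that dispatch; the window-size list is a hypothetical stand-in for `[layer.self_attn.window_size for layer in layers]`:

```python
# -1 marks a global-attention layer; anything else is a sliding-window layer
# that should receive the clamped local mask.
window_sizes = [512, 512, 512, 512, 512, -1]  # illustrative pattern only

for i, window_size in enumerate(window_sizes):
    chosen = "attention_mask" if window_size == -1 else "attention_mask_local"
    print(f"layer {i}: window_size={window_size} -> {chosen}")
```
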
@@ -723,24 +700,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
            config.pad_token_id if config.pad_token_id is not None else -1
        )

-    def get_image_token_mask(self, input_ids):
-        device = input_ids.device
-
-        start_token_id = self.config.boi_token_index
-        K = self.config.mm_tokens_per_image
-
-        mask = torch.zeros_like(input_ids, dtype=torch.bool, device=device)
-        start_positions = (input_ids == start_token_id).nonzero(as_tuple=True)[0]
-        mask_indices = start_positions.unsqueeze(1) + torch.arange(
-            1, K + 1, device=device
-        ).unsqueeze(0)
-
-        valid_mask = mask_indices < input_ids.size(0)
-        mask_indices = mask_indices[valid_mask]
-        mask[mask_indices] = True
-
-        return mask
-
    def get_attention_mask(
        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
    ):
@@ -751,7 +710,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
        batch_size = len(lengths)

        sequence_length = max(lengths)
-        target_length = max_s
+        target_length = sequence_length
        # Create the padding mask from the computed lengths.
        # pad_mask: [batch, sequence_length] where True indicates valid tokens.
        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
@@ -847,7 +806,7 @@ class Gemma3ForConditionalGeneration(nn.Module):

        # # Determine the maximum sequence length (after padding) from query.
        # sequence_length = max(lengths)
-        # target_length = max_s
+        # target_length = sequence_length

        # # Create the padding mask from the computed lengths.
        # # pad_mask: [batch, sequence_length] where True indicates valid tokens.
@@ -885,6 +844,26 @@ class Gemma3ForConditionalGeneration(nn.Module):
        #     input_ids.device
        # )

+        if attention_mask is not None:
+            min_dtype = torch.finfo(inputs_embeds.dtype).min
+            # prefill may be larger than sliding window
+            effective_seq_len = max(
+                position_ids.shape[0], self.config.text_config.sliding_window
+            )
+            sliding_window_mask = torch.tril(
+                torch.ones_like(attention_mask, dtype=torch.bool),
+                diagonal=-self.config.text_config.sliding_window,
+            )
+            attention_mask_local = torch.where(
+                sliding_window_mask, min_dtype, attention_mask
+            )
+            offset = max(0, position_ids.shape[0] - effective_seq_len)
+            attention_mask_local = attention_mask_local[
+                :, :, :, offset : offset + effective_seq_len
+            ]
+        else:
+            attention_mask_local = None
+
        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
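
This is the core of the fix: the local mask is derived once from the dense global mask by pushing every key position that lies `sliding_window` or more tokens behind the query down to the dtype minimum, then applying the same offset arithmetic as the hunk above to the key axis. A self-contained sketch of the same construction on a toy mask; the window length and sizes are made up, and the real mask is 4-D ([batch, 1, query, key]) just like here:

```python
import torch

sliding_window = 4   # illustrative window length
seq_len = 8          # illustrative prefill length
min_dtype = torch.finfo(torch.float32).min

# Toy global causal mask: 0.0 where attention is allowed, min_dtype elsewhere.
base = torch.full((1, 1, seq_len, seq_len), min_dtype)
attention_mask = torch.triu(base, diagonal=1)

# Keys `sliding_window` or more positions behind the query get masked as well,
# mirroring torch.tril(..., diagonal=-sliding_window) + torch.where above.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask_local = torch.where(sliding_window_mask, min_dtype, attention_mask)

# Same offset arithmetic as the hunk above; it only crops the key axis when the
# mask is wider than the effective window.
effective_seq_len = max(seq_len, sliding_window)
offset = max(0, seq_len - effective_seq_len)
attention_mask_local = attention_mask_local[:, :, :, offset : offset + effective_seq_len]

# Each row now has at most `sliding_window` allowed (zero) entries.
print((attention_mask_local[0, 0] == 0).int())
```
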
@@ -895,6 +874,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
            seqlen=seqlen,
            max_s=max_s,
            attention_mask=attention_mask,
+            attention_mask_local=attention_mask_local,
        )

        if lm_head_indices is not None:
@@ -242,6 +242,7 @@ class MistralAttention(torch.nn.Module):
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
            )

        return self.o_proj(
@@ -290,6 +290,7 @@ class MixtralAttention(torch.nn.Module):
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -185,6 +185,7 @@ class Qwen2Attention(torch.nn.Module):
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
            )

        return self.o_proj(
@@ -291,6 +291,7 @@ class Starcoder2Attention(torch.nn.Module):
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
+                window_size_left=self.max_past,
            )

        return self.o_proj(
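
With the module-level SLIDING_WINDOW registry removed (see the FlashCausalLMBatch/FlashCausalLM hunks below), each of these sliding-window models now forwards its own window to the kernels as `window_size_left=self.max_past`. A hedged sketch of the usual None-to--1 normalization behind such an attribute; the config class and helper below are hypothetical stand-ins, not quotes of the actual constructors:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyConfig:
    # Hypothetical stand-in for a model config carrying `sliding_window`.
    sliding_window: Optional[int] = 4096


def resolve_window_size_left(config: DummyConfig) -> int:
    # -1 disables windowing in the attention kernels, so a model without a
    # sliding window maps None to -1 before passing window_size_left.
    return config.sliding_window if config.sliding_window is not None else -1


print(resolve_window_size_left(DummyConfig()))      # 4096
print(resolve_window_size_left(DummyConfig(None)))  # -1
```
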
@@ -82,7 +82,7 @@ class Gemma3Processor(ProcessorMixin):
            do_rescale=False,
            resample=PILImageResampling.BILINEAR,
        )
-        # import ipdb; ipdb.set_trace()
+
        self.image_token_id = tokenizer.image_token_id
        image_tokens_expanded = "".join(
            [tokenizer.image_token] * num_mm_soft_tokens_per_image
@@ -91,8 +91,6 @@ class Gemma3Processor(ProcessorMixin):
            f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
        )

-        # import ipdb; ipdb.set_trace()
-
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.chat_template = chat_template
@@ -633,7 +633,7 @@ class Qwen2_5VisionModel(nn.Module):
            config=config,
            weights=weights,
        )
-        # import ipdb; ipdb.set_trace()
+
        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
@@ -83,24 +83,11 @@ from text_generation_server.models.metadata_kernels import (

 tracer = trace.get_tracer(__name__)

-# Will be set in init
-SLIDING_WINDOW: Optional[int] = None
-

 def small_power_of_2(n: int):
    return 1 << ((n - 1).bit_length() - 1)


-def set_sliding_window(sliding_window: int):
-    global SLIDING_WINDOW
-    SLIDING_WINDOW = sliding_window
-
-
-def get_sliding_windows() -> int:
-    global SLIDING_WINDOW
-    return SLIDING_WINDOW
-
-
 def init_cpu_threads_env(rank_id: int, world_size: int):
    import importlib.util

@@ -1002,10 +989,8 @@ class FlashCausalLMBatch(Batch):
                self.slot_indices,
            )

-        sliding_window = get_sliding_windows()
        position_ids = []
        slot_indices = []
-        prefill_cache_indices = []
        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_cu_outlens = [0]
@@ -1064,14 +1049,6 @@ class FlashCausalLMBatch(Batch):
            # Update
            cumulative_slot_tokens += len(request_slots)

-            # Create tensor to slice into the kv tensor in prefill
-            if sliding_window is not None:
-                request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
-                    dtype=torch.int64,
-                )
-
            # Prefill logprobs is ignored if the request is done prefilling
            prefill_logprobs = r.prefill_logprobs and request_prefilling

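
For reference, the removed indices were used to slice into the KV tensor during prefill (per the removed comment) and only covered the last `sliding_window` positions of each request; with the window now enforced inside the attention kernels, that slicing is no longer needed. A tiny sketch of what the deleted `torch.arange` produced, with illustrative numbers:

```python
import torch

cumulative_length = 100  # illustrative offset of this request in the batch
input_length = 10        # illustrative prefill length
sliding_window = 4       # illustrative window

# The deleted code kept only the last `sliding_window` tokens of the prefill.
request_prefill_cache_indices = torch.arange(
    cumulative_length + max(0, input_length - sliding_window),
    cumulative_length + input_length,
    dtype=torch.int64,
)
print(request_prefill_cache_indices)  # tensor([106, 107, 108, 109])
```
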
@@ -1085,9 +1062,6 @@ class FlashCausalLMBatch(Batch):
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

-            if sliding_window is not None:
-                prefill_cache_indices.append(request_prefill_cache_indices)
-
            ADAPTER_TO_INDEX = get_adapter_to_index()
            if ADAPTER_TO_INDEX:
                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
@@ -1151,24 +1125,18 @@ class FlashCausalLMBatch(Batch):
            position_ids = torch.cat(position_ids)
            if slot_indices:
                slot_indices = torch.cat(slot_indices)
-            if sliding_window is not None:
-                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            if position_ids:
                position_ids = position_ids[0]
            if slot_indices:
                slot_indices = slot_indices[0]
-            if sliding_window is not None:
-                prefill_cache_indices = prefill_cache_indices[0]

        if not has_triton():
            self.position_ids = position_ids.to(device)
            self.slot_indices = slot_indices.to(device)

        self.prefill_cu_outlens = prefill_cu_outlens
-        self.prefill_cache_indices = (
-            prefill_cache_indices.to(device) if sliding_window is not None else None
-        )
+        self.prefill_cache_indices = None

        if all_prefill_logprobs:
            prefill_head_indices = None
@@ -1306,9 +1274,7 @@ class FlashCausalLM(Model):
        if text_config is not None:
            config = text_config

-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
+        if getattr(config, "sliding_window", None) is None:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
@@ -2500,7 +2466,6 @@ class FlashCausalLM(Model):
                page_size=BLOCK_SIZE,
                kv_dtype=self.kv_cache_dtype,
                q_dtype=self.dtype,
-                window_left=self.sliding_window,
            )
        else:
            assert input_lengths_tensor is not None
@@ -2514,5 +2479,4 @@ class FlashCausalLM(Model):
                page_size=BLOCK_SIZE,
                kv_cache_dtype=self.kv_cache_dtype,
                q_dtype=self.dtype,
-                window_left=self.sliding_window,
            )
@@ -110,7 +110,7 @@ class Model(ABC):
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
-            window_size=self.sliding_window,
+            window_size=None,  # Setting this parameter to None disables the block logic with sliding window.
            speculate=self.speculate,
            support_chunking=self.support_chunking,
            use_prefix_caching=PREFIX_CACHING,