Enable all the models. Not tested yet.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-16 22:37:34 -07:00
parent 5d3653943c
commit a07e7437b6
18 changed files with 374 additions and 250 deletions
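Across these model files the change follows one pattern: the `max_s` argument is dropped, a `prefill_cache_indices` tensor is used to skip left-padded tokens when writing into the KV cache, and an `hpu_attention_meta` struct is threaded down to `paged_attention`. A minimal sketch of the KV-slicing part, with simplified shapes and a hypothetical helper name rather than the real TGI call sites:

```python
# Hypothetical, simplified sketch of the per-layer pattern: slice K/V (and,
# when still full-length, the matching slots) with prefill_cache_indices so
# left-padded positions never land in the paged KV cache.
from typing import Optional
import torch


def store_kv_skipping_padding(
    key: torch.Tensor,    # [total_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [total_tokens, num_kv_heads, head_size]
    slots: torch.Tensor,  # [total_tokens] flat cache slot per token
    prefill_cache_indices: Optional[torch.Tensor],  # valid-token indices, or None
):
    """Return the (key, value, slots) that would actually be written to the cache."""
    if prefill_cache_indices is not None:
        key = key[prefill_cache_indices]
        value = value[prefill_cache_indices]
        if slots.size(0) != prefill_cache_indices.size(0):
            slots = slots[prefill_cache_indices]
    return key, value, slots


if __name__ == "__main__":
    # 6 tokens, the first 2 are left padding -> cache only tokens 2..5.
    k = torch.randn(6, 1, 4)
    v = torch.randn(6, 1, 4)
    slots = torch.arange(6)
    keep = torch.tensor([2, 3, 4, 5])
    k2, v2, s2 = store_kv_skipping_padding(k, v, slots, keep)
    print(k2.shape, s2.tolist())  # torch.Size([4, 1, 4]) [2, 3, 4, 5]
```

The per-model diffs below apply this same slicing before each `kv_cache.store` call and pass `hpu_attention_meta` through to the attention kernels.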

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers import (
@ -221,7 +222,8 @@ class FlashCohereAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
query, key, value = qkv.split(
@ -245,9 +247,16 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key,
value=value,
key=key_to_cache,
value=value_to_cache,
slots=slots,
kv_scales=self.kv_scales,
)
@ -274,8 +283,8 @@ class FlashCohereAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(
@ -350,7 +359,8 @@ class FlashCohereLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -364,7 +374,8 @@ class FlashCohereLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
mlp_output = self.mlp(normed_hidden_states)
@ -416,15 +427,14 @@ class FlashCohereModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
@ -439,7 +449,8 @@ class FlashCohereModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -480,8 +491,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -493,7 +504,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
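The rotary embedding lookup also changes everywhere from `get_cos_sin(position_ids, max_s, dtype)` to `get_cos_sin(position_ids)`. A hedged sketch of what such a position-indexed lookup can reduce to, assuming cos/sin tables precomputed up to a maximum position (the actual TGI `rotary_emb` differs in detail):

```python
# Hypothetical sketch: cos/sin gathered directly by absolute position id
# from precomputed tables, so no per-forward max_s or dtype argument is needed.
import torch


class SimpleRotaryTable:
    def __init__(self, max_positions: int, dim: int, base: float = 10000.0):
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)  # [max_positions, dim // 2]
        self._cos = freqs.cos()
        self._sin = freqs.sin()

    def get_cos_sin(self, position_ids: torch.Tensor):
        # position_ids: [total_tokens] -> cos/sin: [total_tokens, dim // 2]
        return self._cos[position_ids], self._sin[position_ids]


rotary = SimpleRotaryTable(max_positions=4096, dim=64)
cos, sin = rotary.get_cos_sin(torch.tensor([0, 1, 2, 5]))
print(cos.shape)  # torch.Size([4, 32])
```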

View File

@ -27,6 +27,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
FastLinear,
@ -312,7 +313,8 @@ class DbrxAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
if self.clip_qkv is not None:
@ -329,10 +331,14 @@ class DbrxAttention(torch.nn.Module):
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -359,8 +365,8 @@ class DbrxAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -397,7 +403,8 @@ class DbrxNormAttentionNorm(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -411,7 +418,8 @@ class DbrxNormAttentionNorm(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
# faster post attention rms norm
@ -631,7 +639,8 @@ class DbrxLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
# Self Attention
attn_output, attn_res = self.attn(
@ -644,7 +653,8 @@ class DbrxLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
moe_output = self.moe(attn_output)
@ -688,15 +698,14 @@ class DbrxModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -710,7 +719,8 @@ class DbrxModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -743,8 +753,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -756,7 +766,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
@ -258,7 +259,8 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
@ -314,10 +316,15 @@ class DeepseekV2Attention(torch.nn.Module):
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key,
value=value,
key=key_to_cache,
value=value_to_cache,
slots=slots,
kv_scales=self.kv_scales,
)
@ -344,8 +351,8 @@ class DeepseekV2Attention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
# Remove padding.
@ -508,7 +515,8 @@ class DeepseekV2Layer(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -522,7 +530,8 @@ class DeepseekV2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
# faster post attention rms norm
@ -571,15 +580,14 @@ class DeepseekV2Model(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -593,7 +601,8 @@ class DeepseekV2Model(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -623,8 +632,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -636,7 +645,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
@ -258,7 +259,8 @@ class DeepseekV3Attention(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
@ -315,9 +317,15 @@ class DeepseekV3Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key,
value=value,
key=key_to_cache,
value=value_to_cache,
slots=slots,
kv_scales=self.kv_scales,
)
@ -344,8 +352,8 @@ class DeepseekV3Attention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
# Remove padding.
@ -517,7 +525,8 @@ class DeepseekV3Layer(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -531,7 +540,8 @@ class DeepseekV3Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
# faster post attention rms norm
@ -580,15 +590,14 @@ class DeepseekV3Model(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -602,7 +611,8 @@ class DeepseekV3Model(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -632,8 +642,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -645,7 +655,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -237,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -252,10 +254,14 @@ class FlashGemma2Attention(torch.nn.Module):
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -284,9 +290,9 @@ class FlashGemma2Attention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(
@ -399,8 +405,9 @@ class FlashGemma2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -414,8 +421,9 @@ class FlashGemma2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
# faster post attention rms norm
@ -467,16 +475,15 @@ class FlashGemma2Model(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor],
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -490,8 +497,9 @@ class FlashGemma2Model(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -538,8 +546,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -552,8 +560,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -209,7 +210,8 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -224,9 +226,14 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -254,8 +261,8 @@ class FlashGemmaAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -327,7 +334,8 @@ class FlashGemmaLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -341,7 +349,8 @@ class FlashGemmaLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
# faster post attention rms norm
@ -389,15 +398,14 @@ class FlashGemmaModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -411,7 +419,8 @@ class FlashGemmaModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -456,8 +465,8 @@ class FlashGemmaForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -470,7 +479,8 @@ class FlashGemmaForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -28,6 +28,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -215,7 +216,8 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -224,9 +226,16 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key,
value=value,
key=key_to_cache,
value=value_to_cache,
slots=slots,
kv_scales=self.kv_scales,
)
@ -253,8 +262,8 @@ class FlashGPT2Attention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -323,7 +332,8 @@ class FlashGPT2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -336,7 +346,8 @@ class FlashGPT2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states = attn_output + residual
@ -389,9 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -405,7 +415,8 @@ class FlashGPT2Model(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states = self.norm(hidden_states)
@ -442,7 +453,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -458,9 +469,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -158,7 +159,8 @@ class FlashGPTJAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -175,9 +177,16 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key,
value=value,
key=key_to_cache,
value=value_to_cache,
slots=slots,
kv_scales=self.kv_scales,
)
@ -204,8 +213,8 @@ class FlashGPTJAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -266,7 +275,8 @@ class FlashGPTJLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
@ -279,7 +289,8 @@ class FlashGPTJLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
feed_forward_hidden_states = self.mlp(hidden_states)
@ -326,16 +337,14 @@ class FlashGPTJModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -349,7 +358,8 @@ class FlashGPTJModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -380,8 +390,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -393,8 +403,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices=prefill_cache_indices,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -206,7 +206,7 @@ class FlashLlamaAttention(torch.nn.Module):
seqlen,
adapter_data,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -447,7 +447,7 @@ class FlashLlamaLayer(nn.Module):
adapter_data,
cross_attention_states,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -559,8 +559,8 @@ class FlashLlamaModel(torch.nn.Module):
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -646,11 +646,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -180,9 +181,9 @@ class MistralAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -232,8 +233,8 @@ class MistralAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(
@ -337,9 +338,9 @@ class MistralLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -353,9 +354,9 @@ class MistralLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
@ -405,17 +406,14 @@ class MistralModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -429,9 +427,9 @@ class MistralModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -480,13 +478,14 @@ class FlashMistralForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
@ -503,9 +502,8 @@ class FlashMistralForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
hpu_attention_meta,
adapter_data,
)
if lm_head_indices is not None:
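For the sliding-window models (Mistral, Mixtral, Qwen2, Starcoder2), `slots` is now only sliced when it still has the length of the whole, padding-included token sequence. A small sketch of that guard with a hypothetical helper name; the assumption is that on some paths `slots` already arrives per valid token:

```python
# Hypothetical sketch of the guard added around slot slicing for
# sliding-window prefill: slice only when slots still covers every token,
# otherwise it has already been reduced to the valid positions.
from typing import Optional
import torch


def maybe_slice_slots(
    slots: torch.Tensor,
    prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    if prefill_cache_indices is not None and slots.size(0) != prefill_cache_indices.size(0):
        # slots has the same length as the whole (padded) kv tensor,
        # so it must be sliced down to the cached positions as well.
        slots = slots[prefill_cache_indices]
    return slots


full_slots = torch.arange(8)               # one slot per token, padding included
keep = torch.tensor([3, 4, 5, 6, 7])       # valid (non-padded) token indices
print(maybe_slice_slots(full_slots, keep).tolist())                       # [3, 4, 5, 6, 7]
print(maybe_slice_slots(torch.tensor([10, 11, 12, 13, 14]), keep).tolist())  # already sliced -> unchanged
```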

View File

@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
@ -237,8 +238,8 @@ class MixtralAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -288,8 +289,8 @@ class MixtralAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -386,8 +387,8 @@ class MixtralLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -401,8 +402,8 @@ class MixtralLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
# faster post attention rms norm
@ -456,17 +457,14 @@ class MixtralModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -480,8 +478,8 @@ class MixtralModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -515,13 +513,14 @@ class FlashMixtralForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
@ -537,9 +536,8 @@ class FlashMixtralForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -149,7 +150,8 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -164,10 +166,14 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(query_rot, key_rot, cos, sin)
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
if prefill_cache_indices is not None:
qkv_to_cache = qkv[prefill_cache_indices]
else:
qkv_to_cache = qkv
kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
key=qkv_to_cache[:, 1],
value=qkv_to_cache[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
@ -194,8 +200,8 @@ class FlashNeoxAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -265,7 +271,8 @@ class FlashNeoXLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -279,7 +286,8 @@ class FlashNeoXLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -303,7 +311,8 @@ class FlashNeoXLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, residual = self.post_attention_layernorm(
@ -357,15 +366,14 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -379,7 +387,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@ -411,7 +420,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -424,7 +433,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -162,7 +163,8 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
# Compute query, key, value and split
qkv = self.query_key_value(hidden_states)
@ -188,9 +190,13 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -216,8 +222,8 @@ class FlashPhiAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -284,7 +290,8 @@ class FlashPhiLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
@ -297,7 +304,8 @@ class FlashPhiLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states = self.resid_dropout(attn_output).add(
@ -349,15 +357,14 @@ class FlashPhiModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -371,7 +378,8 @@ class FlashPhiModel(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -404,8 +412,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -417,7 +425,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -108,8 +109,8 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -159,8 +160,8 @@ class Qwen2Attention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -232,8 +233,8 @@ class Qwen2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states)
@ -247,8 +248,8 @@ class Qwen2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states = attn_output + residual
@ -298,16 +299,13 @@ class Qwen2Model(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids,
true_max_s,
hidden_states.dtype,
)
residual = None
@ -322,8 +320,8 @@ class Qwen2Model(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states)
@ -369,13 +367,15 @@ class Qwen2ForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
if prefill_cache_indices is not None and prefill_cache_indices.size(
0
) != slots.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
@ -393,9 +393,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -184,7 +185,8 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -201,9 +203,14 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -230,8 +237,8 @@ class FlashRWAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -305,7 +312,8 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -321,9 +329,14 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv[:, :, 0].contiguous(),
value=kv[:, :, 1].contiguous(),
key=kv_to_cache[:, :, 0].contiguous(),
value=kv_to_cache[:, :, 1].contiguous(),
slots=slots,
kv_scales=self.kv_scales,
)
@ -350,8 +363,8 @@ class FlashRWLargeAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.dense(
@ -437,7 +450,8 @@ class FlashRWLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -451,7 +465,8 @@ class FlashRWLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
mlp_output = self.mlp(ln_hidden_states)
@ -473,7 +488,8 @@ class FlashRWLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if self.post_attention_layernorm is not None:
@ -560,7 +576,8 @@ class FlashRWLargeLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
# Layer norm.
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
@ -575,7 +592,8 @@ class FlashRWLargeLayer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
# MLP.
@ -636,15 +654,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.h):
@ -658,7 +675,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -688,8 +706,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -701,7 +719,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -271,7 +272,8 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.c_attn(hidden_states)
@ -284,9 +286,14 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
if prefill_cache_indices is not None:
key_value_to_cache = key_value[prefill_cache_indices]
else:
key_value_to_cache = key_value
kv_cache.store(
key=key_value[:, 0],
value=key_value[:, 1],
key=key_value_to_cache[:, 0],
value=key_value_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -313,8 +320,8 @@ class FlashMQAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -379,7 +386,8 @@ class Block(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.self_attn(
@ -389,7 +397,8 @@ class Block(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
@ -443,7 +452,8 @@ class FlashSantacoderModel(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -460,7 +470,8 @@ class FlashSantacoderModel(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -492,7 +503,7 @@ class FlashSantacoderForCausalLM(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -505,7 +516,8 @@ class FlashSantacoderForCausalLM(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
@ -237,9 +238,9 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -289,8 +290,8 @@ class Starcoder2Attention(torch.nn.Module):
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(
@ -450,9 +451,9 @@ class Starcoder2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -466,9 +467,9 @@ class Starcoder2Layer(nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
@ -520,18 +521,15 @@ class Starcoder2Model(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
@ -545,9 +543,9 @@ class Starcoder2Model(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -594,13 +592,14 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
@ -616,10 +615,9 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -1009,24 +1009,22 @@ class FlashCausalLMBatch(Batch):
# padding to left to work with sliding window
# use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate
# the right logit position
if device.type == "hpu":
input_ids_padded = None
input_ids_padded_length = None
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded = []
input_ids_padded_length = []
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id)
input_id_padded = input_id
if padded > 0:
input_id_padded = [0] * padded + input_id_padded
input_ids_padded.append(input_id_padded)
input_ids_padded_length.append(padded)
input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64)
input_ids_padded = torch.tensor(
input_ids_padded, dtype=torch.int64, device=device
)
input_ids_padded = None
input_ids_padded_length = None
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded = []
input_ids_padded_length = []
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id)
input_id_padded = input_id
if padded > 0:
input_id_padded = [0] * padded + input_id_padded
input_ids_padded.append(input_id_padded)
input_ids_padded_length.append(padded)
input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64)
input_ids_padded = torch.tensor(
input_ids_padded, dtype=torch.int64, device=device
)
if isinstance(self.input_ids, list):
if len(self) > 1:
@ -1084,7 +1082,7 @@ class FlashCausalLMBatch(Batch):
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
if device.type == "hpu" and input_ids_padded is not None:
if input_ids_padded is not None:
position_ids.append(
torch.ones(input_ids_padded_length[i], dtype=torch.int32)
)
@ -1111,7 +1109,7 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
if device.type == "hpu" and input_ids_padded is not None:
if input_ids_padded is not None:
# hpu need request_prefill_cache_indices to skip padding in kv cache
sliding_window = get_sliding_windows()
if sliding_window is None:
@ -1235,7 +1233,7 @@ class FlashCausalLMBatch(Batch):
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
if device.type == "hpu" and input_ids_padded is not None:
if input_ids_padded is not None:
self.input_ids = input_ids_padded
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),