mirror of https://github.com/huggingface/text-generation-inference.git
fix: prefer gemma rotary embed and split attention weight
This commit is contained in:
parent 6e8a2110f8
commit 5b3b8fd7b6
@@ -159,6 +159,107 @@ def _load_gqa(config, prefix: str, weights):
     )


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class GemmaRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.register_buffer("inv_freq", None, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                self.base
+                ** (
+                    torch.arange(
+                        0, self.dim, 2, dtype=torch.int64, device=x.device
+                    ).float()
+                    / self.dim
+                )
+            )
+
+        position_ids = position_ids.unsqueeze(0)
+
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        )
+
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = (
+            device_type
+            if isinstance(device_type, str) and device_type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):
+
+            # import ipdb
+            # ipdb.set_trace()
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 class FlashGemmaAttention(torch.nn.Module):
     def __init__(
         self,
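The hunk above vendors the Gemma rotary helpers (repeat_kv, GemmaRotaryEmbedding, rotate_half, apply_rotary_pos_emb) from transformers. Below is a minimal, self-contained sketch of the same cos/sin construction and rotation using only torch; the sizes (seq_len=4, 8 query heads, 1 KV head, head_dim=256) are illustrative, chosen to mirror the values hard-coded later in this diff, not read from any config.

```python
# Minimal sketch of the rotary math added above (not the module itself):
# build cos/sin from inv_freq and apply the rotate_half rotation to q/k.
import torch

dim, base, seq_len = 256, 10000.0, 4                   # illustrative sizes
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
position_ids = torch.arange(seq_len)[None, :].float()  # [1, seq_len]

freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)                # [1, seq_len, dim]
cos, sin = emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(1, 8, seq_len, dim)                    # [bs, heads, seq_len, head_dim]
k = torch.randn(1, 1, seq_len, dim)                    # 1 KV head (MQA-style)
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
k_rot = k * cos.unsqueeze(1) + rotate_half(k) * sin.unsqueeze(1)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```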
@@ -170,9 +271,16 @@ class FlashGemmaAttention(torch.nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim

-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
+        # self._rotary_emb = PositionRotaryEmbedding.static(
+        # config=config,
+        # dim=self.head_size,
+        # base=config.rope_theta,
+        # device=weights.device,
+        # )
+
+        self.rotary_emb = GemmaRotaryEmbedding(
             dim=self.head_size,
+            max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             device=weights.device,
         )
@@ -189,7 +297,21 @@ class FlashGemmaAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        self.query_key_value = load_attention(config, prefix, weights)
+        # TODO: prefer this implementation after debugging
+        # self.query_key_value = load_attention(config, prefix, weights)
+
+        self.k_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.k_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+        )
+        self.q_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+        )

         self.o_proj = TensorParallelRowLinear.load(
             config,
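This is the "split attention weight" half of the commit title: while debugging, the fused query_key_value projection from load_attention is swapped for separate q/k/v column-parallel linears. The sketch below shows the relationship with plain torch.nn.Linear; it assumes the fused layout concatenates q, k and v along the output dimension (as the _load_gqa helper above does), and hidden_size=2048 is an illustrative value, not read from the config.

```python
# Sketch: a fused [q; k; v] projection versus three split projections.
# 8 query heads, 1 KV head and head_dim=256 mirror the values hard-coded
# in the forward pass below; hidden_size=2048 is an illustrative assumption.
import torch

hidden_size, num_heads, num_kv_heads, head_dim = 2048, 8, 1, 256
q_size = num_heads * head_dim        # 2048
kv_size = num_kv_heads * head_dim    # 256

fused = torch.nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)

# Row-slice the fused weight into the three per-projection weights.
q_w, k_w, v_w = fused.weight.split([q_size, kv_size, kv_size], dim=0)

x = torch.randn(5, hidden_size)      # 5 tokens
q, k, v = x @ q_w.T, x @ k_w.T, x @ v_w.T
assert torch.allclose(torch.cat([q, k, v], dim=-1), fused(x), atol=1e-5)
```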
@@ -201,6 +323,7 @@ class FlashGemmaAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

     def forward(
         self,
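The new num_key_value_groups is the n_rep argument that the repeat_kv helper added in the first hunk expects (query heads per KV head). A small self-contained demo of that expansion, with illustrative sizes:

```python
# Demo of the expansion that repeat_kv performs, parameterized by
# num_key_value_groups = num_heads // num_key_value_heads. Sizes are illustrative.
import torch

batch, num_kv_heads, seq_len, head_dim = 1, 1, 4, 256
num_heads = 8
n_rep = num_heads // num_kv_heads    # == num_key_value_groups

kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)
expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

assert expanded.shape == (batch, num_heads, seq_len, head_dim)
assert torch.equal(expanded[:, 0], expanded[:, 1])   # repeated heads share the KV head
```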
@@ -215,25 +338,54 @@ class FlashGemmaAttention(torch.nn.Module):
         max_s,
     ):
         global count
-        qkv = self.query_key_value(hidden_states)
-        query, kv = qkv.split(
-            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
-            ],
-            dim=1,
-        )
-        query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+        # TODO: replace with better implementation after debugging
+        tgt_len, src_len = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = (
+            query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
+        )
+        key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
+        value_states = (
+            value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
+        )
+
+        # reshape for flash/paged attention
+        kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
+
+        # cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, None
+        )
+
+        # replace the kv with the rotated kv
+        kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
+        query = query_states.squeeze(0).transpose(0, 1)
+
+        # TODO: remove after debugging
+        # ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
+        # ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))
+
+        # ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
+        # ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))
+
+        # ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
+        # ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))
+
         if count > 0:
             import ipdb

-            ipdb.set_trace()
-        # looks good prior to attention
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+            # ipdb.set_trace()

         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0],
+            kv[:, 1],
+            kv_cache[0],
+            kv_cache[1],
+            slots,
         )

         # output tensor
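A shape-only sketch of the kv layout built in this forward: concatenating the key/value states along dim=1 and transposing dims 0 and 2 gives a (seq_len, 2, 1, head_dim) tensor, so kv[:, 0] and kv[:, 1] are the per-token key and value slices handed to reshape_and_cache. The sizes are illustrative and assume the single-KV-head case hard-coded in this debugging version.

```python
# Shape-only sketch of the kv tensor built above (illustrative sizes).
import torch

tgt_len, head_dim = 4, 256

key_states = torch.randn(1, 1, tgt_len, head_dim)     # [bs=1, kv_heads=1, seq, dim]
value_states = torch.randn(1, 1, tgt_len, head_dim)

kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
assert kv.shape == (tgt_len, 2, 1, head_dim)

# kv[:, 0] / kv[:, 1] have a [num_tokens, kv_heads, head_dim] layout, and
# squeeze(0).transpose(0, 1) produces the same layout, which is how the
# rotated keys are written back into kv[:, 0] in the forward above.
assert kv[:, 0].shape == (tgt_len, 1, head_dim)
assert torch.equal(kv[:, 0], key_states.squeeze(0).transpose(0, 1))
```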
@@ -264,10 +416,7 @@ class FlashGemmaAttention(torch.nn.Module):
             input_lengths,
             max_s,
         )
-        if count > 0:
-            import ipdb

-            ipdb.set_trace()
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@@ -427,14 +576,14 @@ class FlashGemmaModel(torch.nn.Module):
         global count
         hidden_states = inputs_embeds

-        # Get rotary cos and sin for this forward
-        # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
-        )
-
         residual = None
         for i, layer in enumerate(self.layers):
+            # TODO: prefer a single rotary embedding implementation after debugging
+            cos, sin = self.layers[i].self_attn.rotary_emb(
+                hidden_states,
+                position_ids,
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
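The TODO in the last hunk is about hoisting the cos/sin computation back out of the layer loop: every layer's rotary embedding is constructed with the same dim and base and sees the same position_ids, so the per-layer calls return identical tensors. A minimal sketch of that equivalence, with a stripped-down rotary function standing in for the module and illustrative sizes:

```python
# Per-"layer" rotary calls with identical parameters yield identical cos/sin,
# so computing them once before the loop would be equivalent.
import torch

def rotary_cos_sin(position_ids, dim=256, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(position_ids.float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

position_ids = torch.arange(4)
per_layer = [rotary_cos_sin(position_ids) for _ in range(3)]  # one call per "layer"
cos0, sin0 = per_layer[0]
assert all(torch.equal(c, cos0) and torch.equal(s, sin0) for c, s in per_layer)
```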