fix: prefer gemma rotary embed and split attention weight

This commit is contained in:
drbh 2024-05-14 03:15:32 +00:00 committed by Nicolas Patry
parent 6e8a2110f8
commit 5b3b8fd7b6

View File

@ -159,6 +159,107 @@ def _load_gqa(config, prefix: str, weights):
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class GemmaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base
** (
torch.arange(
0, self.dim, 2, dtype=torch.int64, device=x.device
).float()
/ self.dim
)
)
position_ids = position_ids.unsqueeze(0)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
# import ipdb
# ipdb.set_trace()
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class FlashGemmaAttention(torch.nn.Module):
def __init__(
self,
@ -170,9 +271,16 @@ class FlashGemmaAttention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
# self._rotary_emb = PositionRotaryEmbedding.static(
# config=config,
# dim=self.head_size,
# base=config.rope_theta,
# device=weights.device,
# )
self.rotary_emb = GemmaRotaryEmbedding(
dim=self.head_size,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
device=weights.device,
)
@ -189,7 +297,21 @@ class FlashGemmaAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
# TODO: prefer this implementation after debugging
# self.query_key_value = load_attention(config, prefix, weights)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
)
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
)
self.o_proj = TensorParallelRowLinear.load(
config,
@ -201,6 +323,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
def forward(
self,
@ -215,25 +338,54 @@ class FlashGemmaAttention(torch.nn.Module):
max_s,
):
global count
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
# TODO: replace with better implementation after debugging
tgt_len, src_len = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = (
query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
value_states = (
value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
)
# reshape for flash/paged attention
kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
# cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, None
)
# replace the kv with the rotated kv
kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
query = query_states.squeeze(0).transpose(0, 1)
# TODO: remove after debugging
# ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
# ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))
# ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
# ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))
# ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
# ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))
if count > 0:
import ipdb
ipdb.set_trace()
# looks good prior to attention
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
# ipdb.set_trace()
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
)
# output tensor
@ -264,10 +416,7 @@ class FlashGemmaAttention(torch.nn.Module):
input_lengths,
max_s,
)
if count > 0:
import ipdb
ipdb.set_trace()
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -427,14 +576,14 @@ class FlashGemmaModel(torch.nn.Module):
global count
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
# TODO: prefer a single rotary embedding implementation after debugging
cos, sin = self.layers[i].self_attn.rotary_emb(
hidden_states,
position_ids,
)
hidden_states, residual = layer(
hidden_states,
residual,