mirror of https://github.com/huggingface/text-generation-inference.git
freaking rotary
commit 2e7f6e8012 (parent 424e1b41a2)
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.15.0"
 huggingface-hub = "^0.19.3"
-transformers = "^4.38"
+transformers = { git = "https://github.com/huggingface/transformers", rev = "517a3e670d8fc11374895e870dd0dd041467c7fe" }
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
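The only dependency change swaps the `transformers = "^4.38"` release constraint for a pin to one specific upstream commit, so the reference implementation being borrowed from cannot drift during debugging. A quick post-install sanity check, sketched here under the assumption that the pinned revision already ships the Cohere modeling code (if it does not, the import simply fails):

    import transformers

    print("transformers:", transformers.__version__)
    try:
        # Assumption: the pinned git revision includes the Cohere port.
        from transformers.models.cohere import modeling_cohere  # noqa: F401
        print("Cohere modeling code is available")
    except ImportError:
        print("This revision does not ship transformers.models.cohere")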
@@ -53,8 +53,7 @@ class CohereLayerNorm(nn.Module):
         self.eps = eps

     def forward(self, hidden_states):
-        # if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
-        if True:
+        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
             hidden_states = hidden_states.reshape(
                 -1, self.weight.shape[0], self.weight.shape[1]
             )
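This hunk reverts a debugging hack: with `if True:` every call took the reshape-based slow path, while the restored condition only takes it for wide activations or on ROCm. For reference, a rough standalone sketch of what that slow branch computes, assuming the weight is stored per head as (num_heads, head_size) and that the Cohere layer norm is mean-centred with a scale but no bias; the real class also handles dtype casting and a fused kernel on the fast path:

    import torch

    def cohere_layer_norm_slow(hidden_states, weight, eps=1e-5):
        # hidden_states: (tokens, num_heads * head_size); weight: (num_heads, head_size)
        hidden_states = hidden_states.reshape(-1, weight.shape[0], weight.shape[1])
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + eps)
        hidden_states = weight * hidden_states  # scale only, no bias
        return hidden_states.reshape(-1, weight.shape[0] * weight.shape[1])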
@@ -147,6 +146,93 @@ def _load_gqa(config, prefix: str, weights):
     )


+class CohereRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+    ):
+        super().__init__()
+        self.scaling_factor = scaling_factor
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
+                / self.dim
+            )
+        )
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, device_type, position_ids):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        )
+        position_ids_expanded = position_ids[None, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = (
+            device_type
+            if isinstance(device_type, str) and device_type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (
+                inv_freq_expanded.float().to(position_ids.device)
+                @ position_ids_expanded.float()
+            ).transpose(1, 2)
+            emb = torch.repeat_interleave(freqs, 2, dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos[0], sin[0]
+
+
+def rotate_half(x):
+    # Split and rotate
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+    return rot_x
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
+
+
 class FlashCohereAttention(torch.nn.Module):
     def __init__(
         self,
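The added helpers mirror the eager Hugging Face-style Cohere implementation: `CohereRotaryEmbedding` builds per-position cos/sin tables, `rotate_half` pairs adjacent even/odd channels (the interleaved rotary layout, which matches the `repeat_interleave` above), and `apply_rotary_pos_emb` applies the rotation in float32. A minimal shape check with the two helpers above in scope; the sizes below are arbitrary, chosen only to show the shapes involved:

    import torch

    head_size, num_heads, seq_len = 64, 4, 8
    rotary = CohereRotaryEmbedding(head_size, max_position_embeddings=2048, base=10000)

    position_ids = torch.arange(seq_len)
    cos, sin = rotary("cpu", position_ids)          # each (seq_len, head_size)

    q = torch.randn(seq_len, num_heads, head_size)
    k = torch.randn(seq_len, num_heads, head_size)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
    print(cos.shape, q_rot.shape, k_rot.shape)      # (8, 64) (8, 4, 64) (8, 4, 64)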
@@ -232,14 +318,16 @@ class FlashCohereAttention(torch.nn.Module):
         if self.use_qk_norm:
             query = query.reshape(-1, self.head_size)
             key = key.reshape(-1, self.head_size)
-            query = self.q_norm(query)
-            key = self.k_norm(key)
+            query = self.q_norm(query.contiguous())
+            key = self.k_norm(key.contiguous())

         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
         value = value.view(-1, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, key, cos, sin)
+        query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+        # self.rotary_emb(query, key, true_cos.reshape(*cos.shape), true_sin.reshape(*sin.shape))

         paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

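Two things change at the attention call site: the q/k norms now receive `.contiguous()` inputs, and the fused `self.rotary_emb(query, key, cos, sin)`, which rotates the tensors in place, is replaced by the eager helper, which returns new tensors and therefore must be reassigned. A hedged shape check for that call (the cos/sin below are random placeholders, so only the shapes are meaningful): under GQA the query and key carry different head counts, yet the same tables broadcast over both because `unsqueeze_dim=1` inserts the head axis:

    import torch

    tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
    query = torch.randn(tokens, num_heads, head_size)
    key = torch.randn(tokens, num_kv_heads, head_size)
    cos = torch.randn(tokens, head_size)   # placeholder tables, shapes only
    sin = torch.randn(tokens, head_size)

    q_rot, k_rot = apply_rotary_pos_emb(query, key, cos, sin)
    assert q_rot.shape == query.shape and k_rot.shape == key.shape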
@@ -399,6 +487,11 @@ class FlashCohereModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+        self.rotary_true = CohereRotaryEmbedding(
+            self.head_size,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        )

     def forward(
         self,
@@ -415,9 +508,10 @@ class FlashCohereModel(torch.nn.Module):

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
-        )
+        # cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+        #     position_ids, max_s, hidden_states.dtype
+        # )
+        cos, sin = self.rotary_true(hidden_states.device.type, position_ids)

         residual = None
         for i, layer in enumerate(self.layers):
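In the model forward, the original `get_cos_sin` call on the layer's fused rotary embedding is commented out and the tables now come from `self.rotary_true`. Assuming `position_ids` in the flash path is a flat 1D tensor over all tokens of the packed batch, the returned cos/sin already line up row-for-row with the hidden states, which is why they can be computed once here and passed to every layer without further indexing (the existing "Avoid to index in each layer" comment). A small hedged illustration with two packed sequences:

    import torch

    rotary_true = CohereRotaryEmbedding(64, max_position_embeddings=2048, base=10000)
    position_ids = torch.tensor([0, 1, 2, 0, 1])    # lengths 3 and 2, concatenated
    cos, sin = rotary_true("cpu", position_ids)
    print(cos.shape)                        # torch.Size([5, 64]): one row per token
    print(torch.allclose(cos[0], cos[3]))   # True: same position -> identical row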