Mirror of https://github.com/huggingface/text-generation-inference.git
commit 5dfc9c7613 (parent f41ab12783)

    wip
@@ -9,12 +9,9 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig

# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm

from flash_attn.layers.rotary import RotaryEmbedding


class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
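For reference, the FastLayerNorm path touched by this hunk fuses the residual addition into the normalization via the dropout_layer_norm CUDA extension. A minimal pure-PyTorch sketch of the same contract (assuming the fused call simply returns the normalized output together with the pre-norm sum used as the next residual; the class name here is hypothetical):

from torch import nn


# Hypothetical pure-PyTorch stand-in for the fused residual + layer-norm path;
# the real FastLayerNorm dispatches to the dropout_layer_norm CUDA kernel,
# this sketch only mirrors the intended math.
class RefLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        # Add the residual before normalizing, and hand the pre-norm sum back
        # so the caller can use it as the residual of the next block.
        if residual is not None:
            hidden_states = hidden_states + residual
        return super().forward(hidden_states), hidden_states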
@@ -190,53 +187,12 @@ class TensorParallelEmbedding(nn.Embedding):
        return out


class PositionRotaryEmbedding(RotaryEmbedding):
    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids
        """

        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        rotary_dim = cos.shape[-1]
        q1 = qkv[:, 0, :, :rotary_dim]
        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
        k1 = qkv[:, 1, :, :rotary_dim]
        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]

        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        return qkv


class FlashNeoxAttention(torch.nn.Module):
    def __init__(
        self,
        num_heads,
        hidden_size,
        rotary_pct,
        rotary_emb_base,
        process_group=None,
        reduce=True,
    ):
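The rotary_emb.apply_rotary calls above rotate the query and key head slices in place. As a sanity reference only, a plain-PyTorch equivalent of that rotation (assuming the non-interleaved GPT-NeoX layout, with cos/sin broadcasting over the head dimension; the function name is hypothetical) might look like:

import torch


def apply_rotary_ref(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Assumed plain-PyTorch equivalent of the fused rotary kernel: rotate the
    # first 2 * rotary_dim features of the query (index 0) and key (index 1)
    # heads in place by the cached angles.
    rotary_dim = cos.shape[-1]
    for i in range(2):
        x1 = qkv[:, i, :, :rotary_dim]
        x2 = qkv[:, i, :, rotary_dim : 2 * rotary_dim]
        rot1 = x1 * cos - x2 * sin
        rot2 = x1 * sin + x2 * cos
        x1.copy_(rot1)
        x2.copy_(rot2)
    return qkv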
@@ -245,13 +201,11 @@ class FlashNeoxAttention(torch.nn.Module):
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        rotary_ndims = int(self.head_size * rotary_pct)
        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
        self.softmax_scale = self.head_size ** (-0.5)

        if process_group is None:
            self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
-            self.dense = FastLinear(hidden_size, hidden_size)
+            self.c_proj = FastLinear(hidden_size, hidden_size)
        else:
            self.num_heads = self.num_heads // process_group.size()
            self.query_key_value = TensorParallelColumnLinear(
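For context, the bookkeeping above splits num_heads evenly across the process group and uses the usual 1/sqrt(head_size) softmax scale. A tiny arithmetic sketch with hypothetical GPT-NeoX-20B-like numbers:

# Hypothetical numbers to illustrate the sharding arithmetic above.
hidden_size = 6144
num_heads = 64
world_size = 8  # process_group.size()

head_size = hidden_size // num_heads      # 96
heads_per_rank = num_heads // world_size  # 8 heads handled by each rank
softmax_scale = head_size ** -0.5         # 1 / sqrt(96)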
@@ -259,7 +213,7 @@ class FlashNeoxAttention(torch.nn.Module):
                3 * hidden_size,
                process_group=process_group,
            )
-            self.dense = TensorParallelRowLinear(
+            self.c_proj = TensorParallelRowLinear(
                hidden_size, hidden_size, process_group=process_group, reduce=reduce
            )

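The rename only touches the output projection; under tensor parallelism it remains a row-parallel linear whose partial results are summed across ranks. A minimal sketch of that idea (not the repository's TensorParallelRowLinear implementation, and with a hypothetical class name) for context:

from torch import nn
import torch.distributed as dist


class RowParallelLinearSketch(nn.Module):
    # Illustrative only: each rank holds a 1/world_size slice of the input
    # features and the matching weight shard; the partial matmul outputs are
    # summed with an all-reduce when reduce=True, mirroring the reduce flag
    # passed to the projection above.
    def __init__(self, in_features, out_features, process_group, reduce=True):
        super().__init__()
        world_size = process_group.size()
        assert in_features % world_size == 0
        self.linear = nn.Linear(in_features // world_size, out_features, bias=False)
        self.process_group = process_group
        self.reduce = reduce

    def forward(self, x):
        out = self.linear(x)  # partial result from this rank's input shard
        if self.reduce:
            dist.all_reduce(out, group=self.process_group)
        return out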