diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index f3e35c4c..b722174e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -9,12 +9,9 @@
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 # Flash attention imports
-import rotary_emb
 import flash_attn_cuda
 import dropout_layer_norm
 
-from flash_attn.layers.rotary import RotaryEmbedding
-
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
@@ -190,53 +187,12 @@ class TensorParallelEmbedding(nn.Embedding):
         return out
 
 
-class PositionRotaryEmbedding(RotaryEmbedding):
-    def _update_cos_sin_cache(self, dtype, device, seqlen):
-        # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-        ):
-            self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
-
-    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
-        """
-        Return cos and sin for the asked position ids
-        """
-
-        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
-        cos = torch.index_select(self._cos_cached, 0, position_ids)
-        sin = torch.index_select(self._sin_cached, 0, position_ids)
-        return cos.unsqueeze(1), sin.unsqueeze(1)
-
-    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-        rotary_dim = cos.shape[-1]
-        q1 = qkv[:, 0, :, :rotary_dim]
-        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
-        k1 = qkv[:, 1, :, :rotary_dim]
-        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
-
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        return qkv
-
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
         self,
         num_heads,
         hidden_size,
-        rotary_pct,
-        rotary_emb_base,
         process_group=None,
         reduce=True,
     ):
@@ -245,13 +201,11 @@ class FlashNeoxAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
         self.softmax_scale = self.head_size ** (-0.5)
 
         if process_group is None:
             self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
-            self.dense = FastLinear(hidden_size, hidden_size)
+            self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
             self.num_heads = self.num_heads // process_group.size()
             self.query_key_value = TensorParallelColumnLinear(
@@ -259,7 +213,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 3 * hidden_size,
                 process_group=process_group,
             )
-            self.dense = TensorParallelRowLinear(
+            self.c_proj = TensorParallelRowLinear(
                 hidden_size, hidden_size, process_group=process_group, reduce=reduce
             )