diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a33c6c2d..a959cf20 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -103,7 +103,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.num_heads = self.num_heads // weights.process_group.size()
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 64bd3a40..f9e1f06c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -93,7 +93,15 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
 
         rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+
+        dtype = weights.dtype
+        weights.dtype = torch.float32
+        self.rotary_emb.inv_freq = nn.Parameter(
+            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
+        )
+        weights.dtype = dtype
+
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.query_key_value = load_qkv(
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 0146e5c3..1699622d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -297,7 +297,27 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
-    class PositionRotaryEmbedding(RotaryEmbedding):
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq):
+            super().__init__()
+
+            self.register_buffer("inv_freq", inv_freq)
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+
+        @staticmethod
+        def load(prefix, weights):
+            # XXX: Always load this in float32 !
+            dtype = weights.dtype
+            weights.dtype = torch.float32
+            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
+            weights.dtype = dtype
+            return PositionRotaryEmbedding(inv_freq)
+
+
         def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
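
Note: the dtype round-trip in PositionRotaryEmbedding.load above matters because inv_freq degrades badly in half precision, and the rotation angle is position * inv_freq, so the error grows with sequence length. Below is a minimal sketch of the effect, using a hypothetical Weights stand-in that only mimics the dtype/get_tensor surface used in the diff (the real loader class is not reproduced here):

import torch

class Weights:
    # Hypothetical stand-in for the server's weights loader: get_tensor()
    # returns checkpoint tensors cast to the currently configured dtype.
    def __init__(self, tensors, dtype):
        self._tensors = tensors
        self.dtype = dtype

    def get_tensor(self, name):
        return self._tensors[name].to(self.dtype)

# inv_freq as computed for rotary embeddings: 1 / base**(2i / head_size).
head_size = 128
inv_freq = 1.0 / (
    10000.0 ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size)
)
weights = Weights({"layer.rotary_emb.inv_freq": inv_freq}, dtype=torch.float16)

# The pattern from PositionRotaryEmbedding.load: force float32 for this one
# tensor, then restore the serving dtype for everything else.
dtype = weights.dtype
weights.dtype = torch.float32
exact = weights.get_tensor("layer.rotary_emb.inv_freq")
weights.dtype = dtype

# What a plain half-precision load would have produced instead.
lossy = weights.get_tensor("layer.rotary_emb.inv_freq")

# Quantization error on inv_freq scales linearly with position: at position
# 100k the fp16 path is off by many radians of rotation.
print(((exact - lossy.float()) * 100_000).abs().max())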