diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index eea5f787..3ee344e4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + self.rotary_dim = int(config.rotary_pct * self.head_size) + if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module): ) self.num_heads = self.num_heads // weights.process_group.size() - self.rotary_emb = PositionRotaryEmbedding.load( - config=config, prefix=f"{prefix}.rotary_emb", weights=weights + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.rotary_dim, + base=config.rotary_emb_base, + device=weights.device, ) self.softmax_scale = self.head_size ** (-0.5)