diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index f9e1f06c..b28aa68a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -95,13 +95,6 @@ class FlashNeoxAttention(torch.nn.Module): rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) - dtype = weights.dtype - weights.dtype = torch.float32 - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) - weights.dtype = dtype - self.softmax_scale = self.head_size ** (-0.5) self.query_key_value = load_qkv(