From 58f5f2ee27ecb24b3adb373d6f4980efe6064f90 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 4 Feb 2025 00:30:47 +0000
Subject: [PATCH] fix: adjust signatures with types

---
 server/text_generation_server/layers/rotary.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 576aeb52..f38f6859 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -558,7 +558,7 @@ def apply_llama3_scaling(
 
 
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
-    def __init__(self, inv_freq, scaling_factor, sections):
+    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
         super().__init__(inv_freq, scaling_factor)
         self.sections = sections
         self._cos_cached = None
@@ -586,7 +586,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
         rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
 
-    def _update_cos_sin_cache(self, dtype, device, seqlen):
+    def _update_cos_sin_cache(
+        self, dtype: torch.dtype, device: torch.device, seqlen: int
+    ):
         # always cache the cos/sin for the full sequence length to avoid
         # recomputing if the sequence length is smaller than the cached one
         if (
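
Note: for readers without the full file, here is a minimal, self-contained sketch of
how the two annotated signatures fit together. The stand-in PositionRotaryEmbedding
base class, the inv_freq construction, the sections value, and the cache body below
are illustrative assumptions, not the actual TGI implementation; only the two method
signatures match the patch.

# Hypothetical usage sketch; everything except the two typed signatures is assumed.
import torch


class PositionRotaryEmbedding:
    # Stand-in base class matching the (inv_freq, scaling_factor) call in the patch.
    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float):
        self.inv_freq = inv_freq
        self.scaling_factor = scaling_factor


class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
        super().__init__(inv_freq, scaling_factor)
        self.sections = sections
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(
        self, dtype: torch.dtype, device: torch.device, seqlen: int
    ):
        # Cache cos/sin for the full sequence length so that shorter
        # sequences can reuse the cache instead of recomputing it.
        if self._cos_cached is None or self._cos_cached.shape[0] < seqlen:
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos().to(dtype)
            self._sin_cached = freqs.sin().to(dtype)


# Invented values: inv_freq for a rotary dim of 8, three sections (e.g. t/h/w).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
rope = RotaryPositionEmbeddingMultimodalSections(inv_freq, 1.0, sections=[2, 1, 1])
rope._update_cos_sin_cache(torch.float16, torch.device("cpu"), seqlen=32)
print(rope._cos_cached.shape)  # torch.Size([32, 4])

The annotations document the existing contract rather than change behavior: callers
already pass a torch.Tensor of inverse frequencies, a float scaling factor, and a
list of section sizes, and the cache is keyed by dtype, device, and sequence length.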