fix: adjust signatures with types

drbh 2025-02-04 00:30:47 +00:00
parent 6cb0cb68b4
commit 58f5f2ee27


@@ -558,7 +558,7 @@ def apply_llama3_scaling(
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
-    def __init__(self, inv_freq, scaling_factor, sections):
+    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
         super().__init__(inv_freq, scaling_factor)
         self.sections = sections
         self._cos_cached = None
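
A minimal usage sketch of the newly typed constructor (not part of the commit; the import path, the inverse-frequency construction, and the `sections` values are assumptions for illustration only):

import torch
from text_generation_server.layers.rotary import (
    RotaryPositionEmbeddingMultimodalSections,  # assumed import path
)

dim, base = 128, 10000.0
# standard RoPE inverse frequencies; concrete values are illustrative only
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
rope = RotaryPositionEmbeddingMultimodalSections(
    inv_freq=inv_freq,      # torch.Tensor
    scaling_factor=1.0,     # float
    sections=[16, 24, 24],  # list; assumed per-section split
)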
@@ -586,7 +586,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
         rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
-    def _update_cos_sin_cache(self, dtype, device, seqlen):
+    def _update_cos_sin_cache(
+        self, dtype: torch.dtype, device: torch.device, seqlen: int
+    ):
         # always cache the cos/sin for the full sequence length to avoid
         # recomputing if the sequence length is smaller than the cached one
         if (
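
The hunk ends at the truncated `if (` guard; the comment above it explains that cos/sin are cached for the full sequence length and only recomputed when necessary. A minimal sketch of that caching pattern, assuming the usual recompute conditions (missing cache, longer seqlen, or a dtype/device change) rather than the exact body of this file:

import torch


class _CosSinCacheSketch:
    """Illustrative stand-in for the cached rotary cos/sin logic."""

    def __init__(self, inv_freq: torch.Tensor):
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(
        self, dtype: torch.dtype, device: torch.device, seqlen: int
    ):
        # recompute only when the cache is missing, too short, or on the
        # wrong dtype/device; otherwise the cached tensors are reused
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)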