mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix: adjust signatures with types
This commit is contained in:
parent
6cb0cb68b4
commit
58f5f2ee27
@@ -558,7 +558,7 @@ def apply_llama3_scaling(
 
 
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
-    def __init__(self, inv_freq, scaling_factor, sections):
+    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
         super().__init__(inv_freq, scaling_factor)
         self.sections = sections
         self._cos_cached = None
@@ -586,7 +586,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
         rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
 
-    def _update_cos_sin_cache(self, dtype, device, seqlen):
+    def _update_cos_sin_cache(
+        self, dtype: torch.dtype, device: torch.device, seqlen: int
+    ):
         # always cache the cos/sin for the full sequence length to avoid
         # recomputing if the sequence length is smaller than the cached one
         if (
|
||||
|
Loading…
Reference in New Issue
Block a user