mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix: adjust signatures with types
This commit is contained in:
parent
6cb0cb68b4
commit
58f5f2ee27
@ -558,7 +558,7 @@ def apply_llama3_scaling(
|
|||||||
|
|
||||||
|
|
||||||
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||||
def __init__(self, inv_freq, scaling_factor, sections):
|
def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
|
||||||
super().__init__(inv_freq, scaling_factor)
|
super().__init__(inv_freq, scaling_factor)
|
||||||
self.sections = sections
|
self.sections = sections
|
||||||
self._cos_cached = None
|
self._cos_cached = None
|
||||||
@ -586,7 +586,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
|||||||
rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
|
rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
|
||||||
rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
|
rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(
|
||||||
|
self, dtype: torch.dtype, device: torch.device, seqlen: int
|
||||||
|
):
|
||||||
# always cache the cos/sin for the full sequence length to avoid
|
# always cache the cos/sin for the full sequence length to avoid
|
||||||
# recomputing if the sequence length is smaller than the cached one
|
# recomputing if the sequence length is smaller than the cached one
|
||||||
if (
|
if (
|
||||||
|
Loading…
Reference in New Issue
Block a user