fix: improve multimodal rotary embed caching

drbh 2025-01-22 16:43:53 +00:00
parent 77ef543061
commit d12e075966


@@ -557,6 +557,8 @@ def apply_llama3_scaling(
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
     def __init__(self, inv_freq, scaling_factor, sections):
         super().__init__(inv_freq, scaling_factor)
+        # expand the inv_freq for the 3 sections
+        self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
         self.sections = sections * 2
         self._cos_cached = None
         self._sin_cached = None
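
Note on `sections * 2`: this is Python list repetition, not element-wise doubling. The cached cos/sin are concatenated as `[freqs, freqs]` along the last dimension, so the temporal/height/width section pattern has to repeat across both halves. A minimal sketch, assuming a hypothetical Qwen2-VL-style `mrope_section` of `[16, 24, 24]` and a head dimension of 128:

```python
sections = [16, 24, 24]  # t/h/w sections over head_dim // 2 = 64 frequencies
doubled = sections * 2   # list repetition: [16, 24, 24, 16, 24, 24]
# cos/sin are built as cat([freqs, freqs], dim=-1), so the six chunks
# cover all 128 slots and cycle t, h, w, t, h, w via i % 3
assert sum(doubled) == 128
```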
@@ -568,36 +570,41 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        mrope_section = self.sections
-        unsqueeze_dim = 1
-        split_cos = cos.split(mrope_section, dim=-1)
-        split_sin = sin.split(mrope_section, dim=-1)
-        cos = []
-        for i, m in enumerate(split_cos):
-            cos.append(m[i % 3])
-        cos = torch.cat(cos, dim=-1).unsqueeze(unsqueeze_dim)
-        sin = []
-        for i, m in enumerate(split_sin):
-            sin.append(m[i % 3])
-        sin = torch.cat(sin, dim=-1).unsqueeze(unsqueeze_dim)
-        q = query.transpose(0, 1).unsqueeze(0)
-        k = key.transpose(0, 1).unsqueeze(0)
+        # process multi-modal rotary embeddings
+        split_cos, split_sin = [
+            torch.split(t, self.sections, dim=-1) for t in (cos, sin)
+        ]
+        cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(
+            1
+        )
+        sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(
+            1
+        )
+        # prepare input tensors
+        q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
         rotary_dim = cos.shape[-1]
-        q1 = q[..., :rotary_dim]
+        q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
         q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
-        k1 = k[..., :rotary_dim]
         k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
+        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
         rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
+
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        # always cache the cos/sin for the full sequence length to avoid
+        # recomputing if the sequence length is smaller than the cached one
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached_exp.device != device
+            or self._cos_cached_exp.dtype != dtype
+        ):
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            freqs = freqs.expand(3, -1, -1)
+            self._cos_cached_exp = freqs.cos().to(dtype)
+            self._sin_cached_exp = freqs.sin().to(dtype)

     def get_cos_sin(
         self,
         position_ids: torch.Tensor,
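
`rotary_emb.apply_rotary` above is the fused flash-attn rotary kernel, applied in place to `q1`/`q2` and `k1`/`k2`. As a rough out-of-place reference for what the rewritten forward path computes, here is a sketch in plain PyTorch using the standard rotate-half formulation; `apply_mrope_reference` and its argument layout are illustrative, not part of the diff:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX-style rotation: swap the two halves and negate the second
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_mrope_reference(q, k, cos, sin, sections):
    # `sections` is the doubled list from __init__; each chunk picks its
    # temporal/height/width plane by cycling i % 3, as in the diff above
    split_cos = torch.split(cos, sections, dim=-1)
    split_sin = torch.split(sin, sections, dim=-1)
    cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(1)
    sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(1)
    # out-of-place equivalent of the two fused in-place apply_rotary calls
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```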
@@ -605,21 +612,16 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         dtype: torch.dtype,
     ):
         self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-        inv_freq_expanded = (
-            self.inv_freq[None, None, :, None]
-            .float()
-            .expand(3, position_ids.shape[1], -1, 1)
-        )
-        position_ids_expanded = position_ids[
-            :, :, None, :
-        ].float()  # shape (3, bs, 1, positions)
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
-            2, 3
-        )
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos()
-        sin = emb.sin()
-        return cos.to(dtype), sin.to(dtype)
+        # expand the position_ids to match the shape of the cached cos/sin
+        indices = (
+            position_ids.squeeze(1)
+            .unsqueeze(-1)
+            .expand(-1, -1, self._cos_cached_exp.shape[-1])
+        )
+        cos_c = torch.gather(self._cos_cached_exp, 1, indices)
+        cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
+        sin_c = torch.gather(self._sin_cached_exp, 1, indices)
+        sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
+        return cos_c, sin_c
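
The net effect of the change: the removed path rebuilt the rotary angles on every call via `inv_freq_expanded @ position_ids_expanded`, while the new path computes cos/sin once for the full sequence length in `_update_cos_sin_cache` and only does a `torch.gather` row lookup per call. A toy self-check of that equivalence, with made-up sizes and hypothetical names:

```python
import torch

dim_half, max_s = 4, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(dim_half).float() / dim_half))  # toy

# cache once: angles for every position, expanded over the 3 mrope planes
freqs = torch.outer(torch.arange(max_s).float(), inv_freq)  # (max_s, dim_half)
cos_cached = freqs.cos().expand(3, -1, -1)                  # (3, max_s, dim_half)

# per call: look up rows for the current positions instead of recomputing
position_ids = torch.randint(0, max_s, (3, 5))              # (3, positions)
idx = position_ids.unsqueeze(-1).expand(-1, -1, dim_half)
gathered = torch.gather(cos_cached, 1, idx)                 # (3, positions, dim_half)

# same values as the removed on-the-fly outer-product path
recomputed = torch.outer(position_ids.reshape(-1).float(), inv_freq).cos()
assert torch.allclose(gathered.reshape(-1, dim_half), recomputed)
```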