mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-08-01 04:40:17 +00:00
fix: improve multimodal rotary embed caching
This commit is contained in:
parent
77ef543061
commit
d12e075966
@ -557,6 +557,8 @@ def apply_llama3_scaling(
|
||||
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
def __init__(self, inv_freq, scaling_factor, sections):
    """Initialise the multimodal-sections rotary embedding state.

    Args:
        inv_freq: inverse-frequency tensor. The indexing below only yields
            a 4-D tensor for a 1-D input, so a 1-D tensor is expected.
        scaling_factor: forwarded unchanged to the parent constructor.
        sections: per-section sizes (presumably a list of ints -- TODO
            confirm against callers). Stored repeated twice.
    """
    super().__init__(inv_freq, scaling_factor)
    # expand the inv_freq for the 3 sections
    self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
    # `sections * 2` repeats the sequence, so the sizes can be used to split
    # a last dimension that is built from two concatenated halves.
    self.sections = sections * 2
    self._cos_cached = None
    self._sin_cached = None
    # `_update_cos_sin_cache` reads and writes these two attributes; define
    # them up front so they always exist, with None meaning "not built yet".
    self._cos_cached_exp = None
    self._sin_cached_exp = None
|
||||
@ -568,36 +570,41 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
mrope_section = self.sections
|
||||
unsqueeze_dim = 1
|
||||
|
||||
split_cos = cos.split(mrope_section, dim=-1)
|
||||
split_sin = sin.split(mrope_section, dim=-1)
|
||||
|
||||
cos = []
|
||||
for i, m in enumerate(split_cos):
|
||||
cos.append(m[i % 3])
|
||||
|
||||
cos = torch.cat(cos, dim=-1).unsqueeze(unsqueeze_dim)
|
||||
|
||||
sin = []
|
||||
for i, m in enumerate(split_sin):
|
||||
sin.append(m[i % 3])
|
||||
|
||||
sin = torch.cat(sin, dim=-1).unsqueeze(unsqueeze_dim)
|
||||
|
||||
q = query.transpose(0, 1).unsqueeze(0)
|
||||
k = key.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
# process multi-modal rotary embeddings
|
||||
split_cos, split_sin = [
|
||||
torch.split(t, self.sections, dim=-1) for t in (cos, sin)
|
||||
]
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(
|
||||
1
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(
|
||||
1
|
||||
)
|
||||
# prepare input tensors
|
||||
q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
|
||||
rotary_dim = cos.shape[-1]
|
||||
q1 = q[..., :rotary_dim]
|
||||
q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
|
||||
q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
|
||||
|
||||
k1 = k[..., :rotary_dim]
|
||||
k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
|
||||
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
    """Rebuild the cached cos/sin tables when the current ones cannot be reused.

    The tables are rebuilt when they have never been built, when `seqlen`
    exceeds `self._seq_len_cached`, or when the cached tensors live on a
    different device or have a different dtype than requested.

    Args:
        dtype: dtype the cached cos/sin tensors are cast to.
        device: device on which the position range is created.
        seqlen: number of positions (0 .. seqlen - 1) to cover.

    Side effects:
        Sets `self._seq_len_cached`, `self._cos_cached_exp` and
        `self._sin_cached_exp`. Both tensors have shape
        (3, seqlen, len(self.inv_freq)).
    """
    # The cache attributes may be missing or None before the first build;
    # treat that as "must rebuild" instead of dereferencing `.device` on it.
    cos_cache = getattr(self, "_cos_cached_exp", None)
    # always cache the cos/sin for the full sequence length to avoid
    # recomputing if the sequence length is smaller than the cached one
    if (
        cos_cache is None
        or seqlen > self._seq_len_cached
        or cos_cache.device != device
        or cos_cache.dtype != dtype
    ):
        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
        # (seqlen, n_freq) table of position * inverse frequency
        freqs = torch.outer(t, self.inv_freq.to(device=t.device))
        # add a leading axis of size 3 holding the same table three times
        freqs = freqs.expand(3, -1, -1)
        self._cos_cached_exp = freqs.cos().to(dtype)
        self._sin_cached_exp = freqs.sin().to(dtype)
|
||||
|
||||
def get_cos_sin(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
@ -605,21 +612,16 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
|
||||
inv_freq_expanded = (
|
||||
self.inv_freq[None, None, :, None]
|
||||
.float()
|
||||
.expand(3, position_ids.shape[1], -1, 1)
|
||||
# expand the position_ids to match the shape of the cached cos/sin
|
||||
indices = (
|
||||
position_ids.squeeze(1)
|
||||
.unsqueeze(-1)
|
||||
.expand(-1, -1, self._cos_cached_exp.shape[-1])
|
||||
)
|
||||
cos_c = torch.gather(self._cos_cached_exp, 1, indices)
|
||||
cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
|
||||
|
||||
position_ids_expanded = position_ids[
|
||||
:, :, None, :
|
||||
].float() # shape (3, bs, 1, positions)
|
||||
sin_c = torch.gather(self._sin_cached_exp, 1, indices)
|
||||
sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
|
||||
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
|
||||
2, 3
|
||||
)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
return cos.to(dtype), sin.to(dtype)
|
||||
return cos_c, sin_c
|
||||
|
Loading…
Reference in New Issue
Block a user