Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-06-22 17:10:18 +00:00
fix: improve position id init during cuda warmup for mrope and simplify rotary forward
This commit is contained in:
parent c75c01e9b9
commit 79a2c956de
@@ -579,15 +579,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        # prepare input tensors
-        q, k = [x.transpose(0, 1) for x in (query, key)]
-        rotary_dim = cos.shape[-1]
-        q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
-        q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
-        k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
+        # rotate half the rotary dimension
+        rot = cos.shape[-1] // 2
+        q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1)
+        k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1)

-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
-        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
+        # apply the rotation
+        rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
+        rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # always cache the cos/sin for the full sequence length to avoid
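Here the kernel is passed query together with its rotate-half counterpart q2, and the output arguments alias the inputs, so the rotation is written back in place. As a sanity check on the simplified forward, below is a pure-PyTorch sketch of the same rotate-half math, assuming the standard RoPE formulation out = x * cos + rotate_half(x) * sin; the shapes, frequency base, and helper names are illustrative, not taken from TGI.

# A pure-PyTorch sketch of the rotate-half construction used above, assuming
# the standard RoPE formulation out = x * cos + rotate_half(x) * sin.
# All shapes and the frequency base are illustrative, not taken from TGI.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # move the second half of the last dimension to the front, negated
    rot = x.shape[-1] // 2
    return torch.cat([-x[..., rot:], x[..., :rot]], dim=-1)

def apply_rope_reference(x, cos, sin):
    # cos/sin are [seq_len, 1, rotary_dim] and broadcast over the head axis
    return x * cos + rotate_half(x) * sin

seq_len, n_heads, rotary_dim = 4, 2, 8
query = torch.randn(seq_len, n_heads, rotary_dim)

pos = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(rotary_dim // 2, dtype=torch.float32) / (rotary_dim // 2)))
freqs = torch.outer(pos, inv_freq)                                # [seq_len, rotary_dim // 2]
cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).unsqueeze(1)  # [seq_len, 1, rotary_dim]
sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).unsqueeze(1)

out = apply_rope_reference(query, cos, sin)
print(out.shape)                             # torch.Size([4, 2, 8])
assert torch.allclose(out[0], query[0])      # position 0 is the identity rotation

Dropping the transpose works because the cached cos/sin are now shaped to broadcast over the head dimension; that is exactly what the unsqueeze(1) change in the next hunk provides.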
@@ -628,6 +627,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)

         # double the size and add a batch dimension
-        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
-        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
+        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1)
+        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1)
         return cos, sin
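The switch from unsqueeze(0) to unsqueeze(1) pairs with dropping the transpose in forward: query and key now stay in [seq_len, n_heads, rotary_dim] layout, so the cached tables need a singleton broadcast axis over heads rather than a leading batch axis. A shape-only sketch, with all names illustrative:

# Broadcasting check: why the cached cos/sin gain their singleton axis at
# dim 1 once forward stops transposing to [n_heads, seq_len, rotary_dim].
import torch

seq_len, n_heads, rotary_dim = 6, 4, 16
q = torch.randn(seq_len, n_heads, rotary_dim)
table = torch.randn(seq_len, rotary_dim)  # per-position cos (or sin) values

cos_old = table.unsqueeze(0)  # [1, seq_len, rotary_dim]: matches transposed q
cos_new = table.unsqueeze(1)  # [seq_len, 1, rotary_dim]: matches q as-is

print((q.transpose(0, 1) * cos_old).shape)  # torch.Size([4, 6, 16])
print((q * cos_new).shape)                  # torch.Size([6, 4, 16])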
@@ -1400,11 +1400,11 @@ class FlashCausalLM(Model):
         cache_lengths = [0] * bs
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            if hasattr(self.model, "get_position_ids"):
-                # use model specific position ids for initialization
-                position_ids = self.model.get_position_ids(input_ids)
-            else:
-                position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            # mrope has position ids per section; if so, repeat them n times
+            if self.model.config.rope_scaling["rope_type"] == "mrope":
+                n_sections = len(self.model.config.rope_scaling["mrope_section"])
+                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
             slots = torch.arange(bs, dtype=torch.int64, device=self.device)
             input_lengths_tensor = (
                 torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
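The warmup path previously deferred to self.model.get_position_ids(input_ids) whenever the model provided it; the new code starts from flat zeros and only widens the tensor when rope_type is mrope, giving one column of position ids per mrope section, presumably so that CUDA warmup captures the same tensor shapes as real multimodal batches. A toy reproduction of the new branch, where the config dict stands in for self.model.config.rope_scaling and its values are illustrative only:

# Toy reproduction of the mrope warmup branch above; rope_scaling here is a
# stand-in for self.model.config.rope_scaling (values are illustrative).
import torch

bs = 4
rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

position_ids = torch.zeros(bs, dtype=torch.int32)
if rope_scaling["rope_type"] == "mrope":
    # one column of position ids per mrope section (e.g. temporal/height/width)
    n_sections = len(rope_scaling["mrope_section"])
    position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)

print(position_ids.shape)  # torch.Size([4, 3]) rather than torch.Size([4])

With three sections the warmup position ids become [bs, 3], matching the i % 3 section slicing in _update_cos_sin_cache above.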