mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: remove get_cos_sin_hack dev function
This commit is contained in:
parent
22fdf9344f
commit
ec933282b2
@ -277,32 +277,6 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
# Note: this unsqueeze is not necessary with the ROCm + vLLM RoPE implementation, but we leave it as is to avoid yet another control-flow branch.
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def get_cos_sin_hack(
    self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
):
    """Compute rotary cos/sin tables directly from ``position_ids``.

    The tables are recomputed on every call and then stored on
    ``self._cos_cached`` / ``self._sin_cached``, replacing whatever was
    cached before.

    Args:
        position_ids: Position indices. Indexed as a 3-D tensor; the
            inverse frequencies are expanded to a leading dimension of 3,
            so this presumably expects shape (3, batch, positions) --
            TODO confirm against callers.
        max_s: Not used by this implementation.
        dtype: Dtype the returned (and cached) tensors are cast to.

    Returns:
        A ``(cos, sin)`` tuple; each tensor has the frequency axis
        duplicated along the last dimension.
    """
    # TODO: avoid always computing, use the cache and update it if necessary
    batch = position_ids.shape[1]

    # (1, 1, n_freq, 1) broadcast to (3, batch, n_freq, 1), in float32.
    inv_freq = self.inv_freq[None, None, :, None].float().expand(3, batch, -1, 1)

    # Insert a singleton axis so each position row multiplies each frequency.
    positions = position_ids[:, :, None, :].float()

    # Outer product of frequencies and positions, then move positions
    # ahead of frequencies.
    angles = torch.matmul(inv_freq, positions).transpose(2, 3)
    table = torch.cat((angles, angles), dim=-1)

    cos, sin = table.cos().to(dtype), table.sin().to(dtype)

    # Update cached values
    self._cos_cached, self._sin_cached = cos, sin
    return cos, sin
|
||||
|
||||
|
||||
class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(
|
||||
|
Loading…
Reference in New Issue
Block a user