mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: remove get_cos_sin_hack dev function
This commit is contained in:
parent
22fdf9344f
commit
ec933282b2
@ -277,32 +277,6 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||||
|
|
||||||
def get_cos_sin_hack(
|
|
||||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
|
||||||
):
|
|
||||||
# TODO: avoid always computing, use the cache and update it if necessary
|
|
||||||
inv_freq_expanded = (
|
|
||||||
self.inv_freq[None, None, :, None]
|
|
||||||
.float()
|
|
||||||
.expand(3, position_ids.shape[1], -1, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
position_ids_expanded = position_ids[
|
|
||||||
:, :, None, :
|
|
||||||
].float() # shape (3, bs, 1, positions)
|
|
||||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
|
|
||||||
2, 3
|
|
||||||
)
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
cos = emb.cos().to(dtype)
|
|
||||||
sin = emb.sin().to(dtype)
|
|
||||||
|
|
||||||
# Update cached values
|
|
||||||
self._cos_cached = cos
|
|
||||||
self._sin_cached = sin
|
|
||||||
|
|
||||||
return cos, sin
|
|
||||||
|
|
||||||
|
|
||||||
class SuRotaryEmbedding(PositionRotaryEmbedding):
|
class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
Loading…
Reference in New Issue
Block a user