fix bug on cuda build

This commit is contained in:
Felix Marty 2023-11-23 13:01:19 +00:00
parent 1b0236cb3c
commit bdb6c9d1ed

View File

@ -729,7 +729,7 @@ try:
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1).float(), sin.unsqueeze(1).float()
return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):