mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix wrong refactor of rope
This commit is contained in:
parent
8617d4795a
commit
a992084b9b
@ -595,21 +595,20 @@ try:
|
||||
elif IS_ROCM_SYSTEM:
|
||||
# For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
|
||||
# We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
|
||||
def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
rotary_dim = cos.shape[-1]
|
||||
rotary_dim = cos.shape[-1]
|
||||
|
||||
dtype = x.dtype
|
||||
x_upcast = x.to(torch.float32)
|
||||
cos = cos.to(torch.float32)
|
||||
sin = sin.to(torch.float32)
|
||||
dtype = x.dtype
|
||||
x_upcast = x.to(torch.float32)
|
||||
cos = cos.to(torch.float32)
|
||||
sin = sin.to(torch.float32)
|
||||
|
||||
x1 = x_upcast[..., :rotary_dim]
|
||||
x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
|
||||
x1 = x_upcast[..., :rotary_dim]
|
||||
x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
|
||||
|
||||
# Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
|
||||
x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
|
||||
x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
|
||||
return x
|
||||
# Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
|
||||
x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
|
||||
x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
|
||||
return x
|
||||
else:
|
||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user