# Patch summary (preamble text before the `diff --git` header; not part of any hunk).
#
# Target file: server/text_generation_server/utils/layers.py, one hunk (-595,21 +595,20).
#
# What the hunk changes, inside the `elif IS_ROCM_SYSTEM:` branch:
#   * Removes the nested `def rope_forward_rocm(self, x, cos, sin)` header line.
#   * Removes and re-adds every statement of its former body with identical text.
#     NOTE(review): presumably only the indentation differs (body dedented by one
#     level so it runs inline in the branch) -- the original whitespace is not
#     recoverable from this one-line form, confirm against the real patch.
#   * Net effect on line count is -1, matching the removed `def` line.
#
# What the (unchanged) statements do:
#   * rotary_dim is taken from the last dimension of `cos`.
#   * x, cos and sin are converted to float32; x's original dtype is remembered.
#   * x1 / x2 are the slices [..., :rotary_dim] and [..., rotary_dim : 2 * rotary_dim]
#     of the upcast x.
#   * x is overwritten in place with (x1 * cos - x2 * sin) and (x1 * sin + x2 * cos),
#     each cast back to the original dtype, and x is returned.
#
# NOTE(review): after this patch `x`, `cos` and `sin` are no longer parameters of a
# local function; they must be bound by the enclosing function, which is outside
# this hunk -- verify before applying.
# NOTE(review): this copy of the patch has its newlines collapsed and will not apply
# as-is; regenerate it from the repository if it needs to be applied.
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 9c8b2ade..209df95c 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -595,21 +595,20 @@ try: elif IS_ROCM_SYSTEM: # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm. # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3 - def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - rotary_dim = cos.shape[-1] + rotary_dim = cos.shape[-1] - dtype = x.dtype - x_upcast = x.to(torch.float32) - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) + dtype = x.dtype + x_upcast = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) - x1 = x_upcast[..., :rotary_dim] - x2 = x_upcast[..., rotary_dim : 2 * rotary_dim] + x1 = x_upcast[..., :rotary_dim] + x2 = x_upcast[..., rotary_dim : 2 * rotary_dim] - # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well. - x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype) - x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype) - return x + # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well. + x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype) + x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype) + return x else: raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")