fix wrong refactor of rope

This commit is contained in:
Felix Marty 2023-11-14 10:43:02 +00:00
parent 8617d4795a
commit a992084b9b

View File

@ -595,7 +595,6 @@ try:
elif IS_ROCM_SYSTEM:
# For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
# We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
dtype = x.dtype