From 57f9685dc3257a5cbfa593ffe5ca531bb3e53149 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 8 Oct 2024 03:15:09 +0800
Subject: [PATCH] enable mllama in intel platform (#2610)

Signed-off-by: Wang, Yi A
---
 .../models/custom_modeling/mllama.py          | 84 +++++++++++++------
 1 file changed, 60 insertions(+), 24 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index 73536bd6..6e091a74 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import flash_attn_2_cuda
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import flash_attn_2_cuda
 
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
@@ -698,29 +703,60 @@ class MllamaTextCrossAttention(nn.Module):
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
-        attn_output = flash_attn_2_cuda.varlen_fwd(
-            query_states,
-            key_states,
-            value_states,
-            None,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            None,
-            None,  # block_tables
-            None,
-            max_q,
-            max_k,
-            0.0,
-            self.softmax_scale,
-            False,
-            causal,  # Causal
-            -1,  # window_size_left,
-            -1,
-            0.0,  # softcap
-            False,
-            None,
-        )[0]
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query_states)
+            ipex.llm.functional.varlen_attention(
+                (
+                    query_states.contiguous()
+                    if query_states.device.type == "xpu"
+                    else query_states
+                ),
+                (
+                    key_states.contiguous()
+                    if key_states.device.type == "xpu"
+                    else key_states
+                ),
+                (
+                    value_states.contiguous()
+                    if value_states.device.type == "xpu"
+                    else value_states
+                ),
+                attn_output,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query_states,
+                key_states,
+                value_states,
+                None,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables
+                None,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,  # Causal
+                -1,  # window_size_left,
+                -1,
+                0.0,  # softcap
+                False,
+                None,
+            )[0]
         attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
         return attn_output
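
Note (not part of the patch): the hunk above picks an attention backend at import time via SYSTEM and then calls either ipex.llm.functional.varlen_attention or flash_attn_2_cuda.varlen_fwd. The sketch below restates that dispatch as a standalone helper for reference; the helper name varlen_attention_dispatch and the inline comments on positional arguments are illustrative assumptions, while the argument lists mirror the diff.

import torch

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import flash_attn_2_cuda


def varlen_attention_dispatch(
    q, k, v, cu_seqlen_q, cu_seqlen_k, max_q, max_k, softmax_scale, causal
):
    # Hypothetical helper mirroring the backend dispatch added in the hunk above.
    if SYSTEM == "ipex":
        # The ipex kernel writes into a preallocated output tensor; xpu tensors
        # are made contiguous first, as in the patch.
        out = torch.empty_like(q)
        ipex.llm.functional.varlen_attention(
            q.contiguous() if q.device.type == "xpu" else q,
            k.contiguous() if k.device.type == "xpu" else k,
            v.contiguous() if v.device.type == "xpu" else v,
            out,
            cu_seqlen_q,
            cu_seqlen_k,
            max_q,
            max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )
        return out
    # flash_attn_2_cuda.varlen_fwd returns a tuple; the attention output is element 0.
    return flash_attn_2_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlen_q,
        cu_seqlen_k,
        None,
        None,
        None,  # block_tables
        None,
        max_q,
        max_k,
        0.0,
        softmax_scale,
        False,
        causal,
        -1,  # window_size_left
        -1,
        0.0,  # softcap
        False,
        None,
    )[0]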