From 7c8694545f738b2b90173a68d6ecc938790e8584 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 1 Jul 2025 00:21:49 -0700 Subject: [PATCH] refine code Signed-off-by: Wang, Yi A --- server/text_generation_server/layers/lora.py | 234 +++++++++---------- 1 file changed, 109 insertions(+), 125 deletions(-) diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index 5e5a737b..dc5072bb 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -15,6 +15,14 @@ if SYSTEM == "cuda": else: punica_sgmv = None +if SYSTEM == "ipex": + from intel_extension_for_pytorch.llm.functional import ( + bgmv_expand, + bgmv_shrink, + sgmv_expand, + sgmv_shrink, + ) + if TYPE_CHECKING: from text_generation_server.adapters import AdapterBatchData @@ -44,9 +52,9 @@ class LoraLinear(nn.Module): data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) if ( - punica_sgmv is not None - and data is not None - and data.can_vectorize(self.process_group) + data is not None + and (SYSTEM == "ipex" + or (punica_sgmv is not None and data.can_vectorize(self.process_group))) ): # In tensor-parallel configurations, each GPU processes a specific segment of the output. 
# The 'result' tensor represents the full output, which can vary in size based on @@ -66,145 +74,121 @@ class LoraLinear(nn.Module): proj = result for r, rank_segments in data.rank_data.items(): - lora_a_ptr = rank_segments.lora_a_ptr - lora_b_ptr = rank_segments.lora_b_ptr + if SYSTEM == "ipex": + lora_a_ptr = rank_segments.lora_a_ptr[ + :, self.layer_id, : + ].contiguous() + lora_b_ptr = rank_segments.lora_b_ptr[ + :, self.layer_id, : + ].contiguous() + else: + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr if lora_a_ptr is None or lora_b_ptr is None: raise ValueError("LoRA data is missing") if data.use_sgmv: - # Use SGMV for prefill - v = punica_sgmv.lora_a_sgmv_cutlass( - input, - rank_segments.tmp_shrink, - lora_a_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - r, - ) + if SYSTEM == "ipex": + # Use SGMV for prefill + seq_len_tensor = ( + rank_segments.segment_ends - rank_segments.segment_starts + ).to(torch.int64) + b_seq_start_loc = rank_segments.segment_starts.to(torch.int64) + total_tokens = seq_len_tensor.sum() + v = torch.zeros( + (total_tokens, r), dtype=input.dtype, device=input.device + ) + bs = seq_len_tensor.shape[0] + sgmv_shrink( + input, + lora_a_ptr, + v, + b_seq_start_loc, + seq_len_tensor, + rank_segments.indices, + bs, + seq_len_tensor.max().item(), + 1.0, + ) + else: + # Use SGMV for prefill + v = punica_sgmv.lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) if self.process_group.size() > 1: v = self.collect_lora_a(v) - - punica_sgmv.lora_b_sgmv_cutlass( - proj, - v, - rank_segments.tmp_expand, - lora_b_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - ) + if SYSTEM == "ipex": + sgmv_expand( + v, + lora_b_ptr, + proj, + b_seq_start_loc, + seq_len_tensor, + rank_segments.indices, + bs, + seq_len_tensor.max().item(), + 
add_inputs=True, + ) + else: + punica_sgmv.lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) else: # Use BGMV for decode v = torch.zeros( (input.size(0), r), dtype=input.dtype, device=input.device ) - # TODO: error with [-1, 0], but not [0, -1] - punica_sgmv.add_lora_a_bgmv( - v, - input, - lora_a_ptr, - rank_segments.indices, - self.layer_id, - ) + if SYSTEM == "ipex": + bgmv_shrink( + input, + lora_a_ptr, + v, + rank_segments.indices, + 1.0, + ) + else: + # TODO: error with [-1, 0], but not [0, -1] + punica_sgmv.add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) if self.process_group.size() > 1: v = self.collect_lora_a(v) - punica_sgmv.add_lora_b_bgmv( - proj, - v, - lora_b_ptr, - rank_segments.indices, - self.layer_id, - ) - - if end_idx - start_idx != result.shape[1]: - result[:, start_idx:end_idx] += proj - elif SYSTEM == "ipex" and data is not None: - from intel_extension_for_pytorch.llm.functional import ( - bgmv_expand, - bgmv_shrink, - sgmv_expand, - sgmv_shrink, - ) - - # In IPEX, we provide the same API for sgmv and bgmv - if end_idx - start_idx != result.shape[1]: - proj = torch.zeros_like(result[:, start_idx:end_idx]) - else: - proj = result - - for r, rank_segments in data.rank_data.items(): - lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous() - lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous() - - if lora_a_ptr is None or lora_b_ptr is None: - raise ValueError("LoRA data is missing") - - if data.use_sgmv: - # Use SGMV for prefill - seq_len_tensor = ( - rank_segments.segment_ends - rank_segments.segment_starts - ).to(torch.int64) - b_seq_start_loc = rank_segments.segment_starts.to(torch.int64) - total_tokens = seq_len_tensor.sum() - v = torch.zeros( - (total_tokens, r), dtype=input.dtype, device=input.device - ) - bs = seq_len_tensor.shape[0] - sgmv_shrink( - 
input, - lora_a_ptr, - v, - b_seq_start_loc, - seq_len_tensor, - rank_segments.indices, - bs, - seq_len_tensor.max().item(), - 1.0, - ) - - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - sgmv_expand( - v, - lora_b_ptr, - proj, - b_seq_start_loc, - seq_len_tensor, - rank_segments.indices, - bs, - seq_len_tensor.max().item(), - add_inputs=True, - ) - else: - # Use BGMV for decode - v = torch.zeros( - (input.size(0), r), dtype=input.dtype, device=input.device - ) - # TODO: error with [-1, 0], but not [0, -1] - bgmv_shrink( - input, - lora_a_ptr, - v, - rank_segments.indices, - 1.0, - ) - - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - bgmv_expand( - v, - lora_b_ptr, - proj, - rank_segments.indices, - add_inputs=True, - ) + if SYSTEM == "ipex": + bgmv_expand( + v, + lora_b_ptr, + proj, + rank_segments.indices, + add_inputs=True, + ) + else: + punica_sgmv.add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) if end_idx - start_idx != result.shape[1]: result[:, start_idx:end_idx] += proj