refine code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-07-01 00:21:49 -07:00
parent 3338b34ba4
commit 7c8694545f

View File

@ -15,6 +15,14 @@ if SYSTEM == "cuda":
else: else:
punica_sgmv = None punica_sgmv = None
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.functional import (
bgmv_expand,
bgmv_shrink,
sgmv_expand,
sgmv_shrink,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData from text_generation_server.adapters import AdapterBatchData
@ -44,9 +52,9 @@ class LoraLinear(nn.Module):
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
if ( if (
punica_sgmv is not None data is not None
and data is not None and SYSTEM == "ipex"
and data.can_vectorize(self.process_group) or (punica_sgmv is not None and data.can_vectorize(self.process_group))
): ):
# In tensor-parallel configurations, each GPU processes a specific segment of the output. # In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on # The 'result' tensor represents the full output, which can vary in size based on
@ -66,6 +74,14 @@ class LoraLinear(nn.Module):
proj = result proj = result
for r, rank_segments in data.rank_data.items(): for r, rank_segments in data.rank_data.items():
if SYSTEM == "ipex":
lora_a_ptr = rank_segments.lora_a_ptr[
:, self.layer_id, :
].contiguous()
lora_b_ptr = rank_segments.lora_b_ptr[
:, self.layer_id, :
].contiguous()
else:
lora_a_ptr = rank_segments.lora_a_ptr lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr lora_b_ptr = rank_segments.lora_b_ptr
@ -73,78 +89,7 @@ class LoraLinear(nn.Module):
raise ValueError("LoRA data is missing") raise ValueError("LoRA data is missing")
if data.use_sgmv: if data.use_sgmv:
# Use SGMV for prefill if SYSTEM == "ipex":
v = punica_sgmv.lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
punica_sgmv.lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
punica_sgmv.add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
punica_sgmv.add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
elif SYSTEM == "ipex" and data is not None:
from intel_extension_for_pytorch.llm.functional import (
bgmv_expand,
bgmv_shrink,
sgmv_expand,
sgmv_shrink,
)
# In IPEX, we provide the same API for sgmv and bgmv
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill # Use SGMV for prefill
seq_len_tensor = ( seq_len_tensor = (
rank_segments.segment_ends - rank_segments.segment_starts rank_segments.segment_ends - rank_segments.segment_starts
@ -166,10 +111,21 @@ class LoraLinear(nn.Module):
seq_len_tensor.max().item(), seq_len_tensor.max().item(),
1.0, 1.0,
) )
else:
# Use SGMV for prefill
v = punica_sgmv.lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1: if self.process_group.size() > 1:
v = self.collect_lora_a(v) v = self.collect_lora_a(v)
if SYSTEM == "ipex":
sgmv_expand( sgmv_expand(
v, v,
lora_b_ptr, lora_b_ptr,
@ -181,12 +137,22 @@ class LoraLinear(nn.Module):
seq_len_tensor.max().item(), seq_len_tensor.max().item(),
add_inputs=True, add_inputs=True,
) )
else:
punica_sgmv.lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else: else:
# Use BGMV for decode # Use BGMV for decode
v = torch.zeros( v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device (input.size(0), r), dtype=input.dtype, device=input.device
) )
# TODO: error with [-1, 0], but not [0, -1] if SYSTEM == "ipex":
bgmv_shrink( bgmv_shrink(
input, input,
lora_a_ptr, lora_a_ptr,
@ -194,10 +160,20 @@ class LoraLinear(nn.Module):
rank_segments.indices, rank_segments.indices,
1.0, 1.0,
) )
else:
# TODO: error with [-1, 0], but not [0, -1]
punica_sgmv.add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1: if self.process_group.size() > 1:
v = self.collect_lora_a(v) v = self.collect_lora_a(v)
if SYSTEM == "ipex":
bgmv_expand( bgmv_expand(
v, v,
lora_b_ptr, lora_b_ptr,
@ -205,6 +181,14 @@ class LoraLinear(nn.Module):
rank_segments.indices, rank_segments.indices,
add_inputs=True, add_inputs=True,
) )
else:
punica_sgmv.add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]: if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj result[:, start_idx:end_idx] += proj