refine code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A <yi.a.wang@intel.com>
Date: 2025-07-01 00:21:49 -07:00
parent 3338b34ba4
commit 7c8694545f


@@ -15,6 +15,14 @@ if SYSTEM == "cuda":
 else:
     punica_sgmv = None
 
+if SYSTEM == "ipex":
+    from intel_extension_for_pytorch.llm.functional import (
+        bgmv_expand,
+        bgmv_shrink,
+        sgmv_expand,
+        sgmv_shrink,
+    )
+
 
 if TYPE_CHECKING:
     from text_generation_server.adapters import AdapterBatchData
@@ -44,9 +52,9 @@ class LoraLinear(nn.Module):
         data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
 
         if (
-            punica_sgmv is not None
-            and data is not None
-            and data.can_vectorize(self.process_group)
+            data is not None
+            and SYSTEM == "ipex"
+            or (punica_sgmv is not None and data.can_vectorize(self.process_group))
         ):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.
             # The 'result' tensor represents the full output, which can vary in size based on
@@ -66,145 +74,121 @@ class LoraLinear(nn.Module):
                 proj = result
 
             for r, rank_segments in data.rank_data.items():
-                lora_a_ptr = rank_segments.lora_a_ptr
-                lora_b_ptr = rank_segments.lora_b_ptr
+                if SYSTEM == "ipex":
+                    lora_a_ptr = rank_segments.lora_a_ptr[
+                        :, self.layer_id, :
+                    ].contiguous()
+                    lora_b_ptr = rank_segments.lora_b_ptr[
+                        :, self.layer_id, :
+                    ].contiguous()
+                else:
+                    lora_a_ptr = rank_segments.lora_a_ptr
+                    lora_b_ptr = rank_segments.lora_b_ptr
 
                 if lora_a_ptr is None or lora_b_ptr is None:
                     raise ValueError("LoRA data is missing")
 
                 if data.use_sgmv:
-                    # Use SGMV for prefill
-                    v = punica_sgmv.lora_a_sgmv_cutlass(
-                        input,
-                        rank_segments.tmp_shrink,
-                        lora_a_ptr,
-                        rank_segments.segment_starts,
-                        rank_segments.segment_ends,
-                        self.layer_id,
-                        r,
-                    )
+                    if SYSTEM == "ipex":
+                        # Use SGMV for prefill
+                        seq_len_tensor = (
+                            rank_segments.segment_ends - rank_segments.segment_starts
+                        ).to(torch.int64)
+                        b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
+                        total_tokens = seq_len_tensor.sum()
+                        v = torch.zeros(
+                            (total_tokens, r), dtype=input.dtype, device=input.device
+                        )
+                        bs = seq_len_tensor.shape[0]
+                        sgmv_shrink(
+                            input,
+                            lora_a_ptr,
+                            v,
+                            b_seq_start_loc,
+                            seq_len_tensor,
+                            rank_segments.indices,
+                            bs,
+                            seq_len_tensor.max().item(),
+                            1.0,
+                        )
+                    else:
+                        # Use SGMV for prefill
+                        v = punica_sgmv.lora_a_sgmv_cutlass(
+                            input,
+                            rank_segments.tmp_shrink,
+                            lora_a_ptr,
+                            rank_segments.segment_starts,
+                            rank_segments.segment_ends,
+                            self.layer_id,
+                            r,
+                        )
 
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
 
-                    punica_sgmv.lora_b_sgmv_cutlass(
-                        proj,
-                        v,
-                        rank_segments.tmp_expand,
-                        lora_b_ptr,
-                        rank_segments.segment_starts,
-                        rank_segments.segment_ends,
-                        self.layer_id,
-                    )
+                    if SYSTEM == "ipex":
+                        sgmv_expand(
+                            v,
+                            lora_b_ptr,
+                            proj,
+                            b_seq_start_loc,
+                            seq_len_tensor,
+                            rank_segments.indices,
+                            bs,
+                            seq_len_tensor.max().item(),
+                            add_inputs=True,
+                        )
+                    else:
+                        punica_sgmv.lora_b_sgmv_cutlass(
+                            proj,
+                            v,
+                            rank_segments.tmp_expand,
+                            lora_b_ptr,
+                            rank_segments.segment_starts,
+                            rank_segments.segment_ends,
+                            self.layer_id,
+                        )
                 else:
                     # Use BGMV for decode
                     v = torch.zeros(
                         (input.size(0), r), dtype=input.dtype, device=input.device
                     )
-                    # TODO: error with [-1, 0], but not [0, -1]
-                    punica_sgmv.add_lora_a_bgmv(
-                        v,
-                        input,
-                        lora_a_ptr,
-                        rank_segments.indices,
-                        self.layer_id,
-                    )
+                    if SYSTEM == "ipex":
+                        bgmv_shrink(
+                            input,
+                            lora_a_ptr,
+                            v,
+                            rank_segments.indices,
+                            1.0,
+                        )
+                    else:
+                        # TODO: error with [-1, 0], but not [0, -1]
+                        punica_sgmv.add_lora_a_bgmv(
+                            v,
+                            input,
+                            lora_a_ptr,
+                            rank_segments.indices,
+                            self.layer_id,
+                        )
 
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
 
-                    punica_sgmv.add_lora_b_bgmv(
-                        proj,
-                        v,
-                        lora_b_ptr,
-                        rank_segments.indices,
-                        self.layer_id,
-                    )
-
-            if end_idx - start_idx != result.shape[1]:
-                result[:, start_idx:end_idx] += proj
-        elif SYSTEM == "ipex" and data is not None:
-            from intel_extension_for_pytorch.llm.functional import (
-                bgmv_expand,
-                bgmv_shrink,
-                sgmv_expand,
-                sgmv_shrink,
-            )
-
-            # In IPEX, we provide the same API for sgmv and bgmv
-            if end_idx - start_idx != result.shape[1]:
-                proj = torch.zeros_like(result[:, start_idx:end_idx])
-            else:
-                proj = result
-
-            for r, rank_segments in data.rank_data.items():
-                lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
-                lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
-                if lora_a_ptr is None or lora_b_ptr is None:
-                    raise ValueError("LoRA data is missing")
-                if data.use_sgmv:
-                    # Use SGMV for prefill
-                    seq_len_tensor = (
-                        rank_segments.segment_ends - rank_segments.segment_starts
-                    ).to(torch.int64)
-                    b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
-                    total_tokens = seq_len_tensor.sum()
-                    v = torch.zeros(
-                        (total_tokens, r), dtype=input.dtype, device=input.device
-                    )
-                    bs = seq_len_tensor.shape[0]
-                    sgmv_shrink(
-                        input,
-                        lora_a_ptr,
-                        v,
-                        b_seq_start_loc,
-                        seq_len_tensor,
-                        rank_segments.indices,
-                        bs,
-                        seq_len_tensor.max().item(),
-                        1.0,
-                    )
-                    if self.process_group.size() > 1:
-                        v = self.collect_lora_a(v)
-                    sgmv_expand(
-                        v,
-                        lora_b_ptr,
-                        proj,
-                        b_seq_start_loc,
-                        seq_len_tensor,
-                        rank_segments.indices,
-                        bs,
-                        seq_len_tensor.max().item(),
-                        add_inputs=True,
-                    )
-                else:
-                    # Use BGMV for decode
-                    v = torch.zeros(
-                        (input.size(0), r), dtype=input.dtype, device=input.device
-                    )
-                    # TODO: error with [-1, 0], but not [0, -1]
-                    bgmv_shrink(
-                        input,
-                        lora_a_ptr,
-                        v,
-                        rank_segments.indices,
-                        1.0,
-                    )
-                    if self.process_group.size() > 1:
-                        v = self.collect_lora_a(v)
-                    bgmv_expand(
-                        v,
-                        lora_b_ptr,
-                        proj,
-                        rank_segments.indices,
-                        add_inputs=True,
-                    )
+                    if SYSTEM == "ipex":
+                        bgmv_expand(
+                            v,
+                            lora_b_ptr,
+                            proj,
+                            rank_segments.indices,
+                            add_inputs=True,
+                        )
+                    else:
+                        punica_sgmv.add_lora_b_bgmv(
+                            proj,
+                            v,
+                            lora_b_ptr,
+                            rank_segments.indices,
+                            self.layer_id,
+                        )
 
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
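For context on the shrink/expand split that both backends implement, the sketch below shows, in plain PyTorch, the computation the segmented (SGMV-style) prefill path performs for one rank group: each request's token segment is projected down to rank r by its adapter's A matrix and back up by its B matrix, accumulating into the output. The function name and tensor layouts here are illustrative assumptions, not part of the repository or of the punica/IPEX APIs; the fused kernels do this work without the Python loop, and the decode-time BGMV path is roughly the one-token-per-segment special case driven by `rank_segments.indices`.

```python
import torch

def lora_sgmv_reference(
    x: torch.Tensor,               # (total_tokens, hidden) packed prefill activations
    proj: torch.Tensor,            # (total_tokens, out_features) accumulated in place
    lora_a: torch.Tensor,          # (num_adapters, hidden, r) adapter "shrink" weights (assumed layout)
    lora_b: torch.Tensor,          # (num_adapters, r, out_features) adapter "expand" weights (assumed layout)
    segment_starts: torch.Tensor,  # (num_segments,) first token index of each request
    segment_ends: torch.Tensor,    # (num_segments,) one-past-last token index
    indices: torch.Tensor,         # (num_segments,) adapter id used by each segment
) -> None:
    """Segmented LoRA: shrink (x @ A), then expand (v @ B), added into proj."""
    for seg in range(segment_starts.numel()):
        s, e = int(segment_starts[seg]), int(segment_ends[seg])
        adapter = int(indices[seg])
        v = x[s:e] @ lora_a[adapter]      # shrink: (seq, hidden) -> (seq, r)
        proj[s:e] += v @ lora_b[adapter]  # expand: (seq, r) -> (seq, out_features)

if __name__ == "__main__":
    # Tiny usage example: two three-token segments sharing one adapter.
    x = torch.randn(6, 16)
    proj = torch.zeros(6, 32)
    lora_a = torch.randn(1, 16, 4)
    lora_b = torch.randn(1, 4, 32)
    lora_sgmv_reference(
        x, proj, lora_a, lora_b,
        segment_starts=torch.tensor([0, 3]),
        segment_ends=torch.tensor([3, 6]),
        indices=torch.tensor([0, 0]),
    )
```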