refine code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-07-01 00:21:49 -07:00
parent 3338b34ba4
commit 7c8694545f

View File

@ -15,6 +15,14 @@ if SYSTEM == "cuda":
else:
punica_sgmv = None
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.functional import (
bgmv_expand,
bgmv_shrink,
sgmv_expand,
sgmv_shrink,
)
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
@ -44,9 +52,9 @@ class LoraLinear(nn.Module):
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
if (
punica_sgmv is not None
and data is not None
and data.can_vectorize(self.process_group)
data is not None
and SYSTEM == "ipex"
or (punica_sgmv is not None and data.can_vectorize(self.process_group))
):
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on
@ -66,145 +74,121 @@ class LoraLinear(nn.Module):
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if SYSTEM == "ipex":
lora_a_ptr = rank_segments.lora_a_ptr[
:, self.layer_id, :
].contiguous()
lora_b_ptr = rank_segments.lora_b_ptr[
:, self.layer_id, :
].contiguous()
else:
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
v = punica_sgmv.lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if SYSTEM == "ipex":
# Use SGMV for prefill
seq_len_tensor = (
rank_segments.segment_ends - rank_segments.segment_starts
).to(torch.int64)
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
total_tokens = seq_len_tensor.sum()
v = torch.zeros(
(total_tokens, r), dtype=input.dtype, device=input.device
)
bs = seq_len_tensor.shape[0]
sgmv_shrink(
input,
lora_a_ptr,
v,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
1.0,
)
else:
# Use SGMV for prefill
v = punica_sgmv.lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
punica_sgmv.lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
if SYSTEM == "ipex":
sgmv_expand(
v,
lora_b_ptr,
proj,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
add_inputs=True,
)
else:
punica_sgmv.lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
punica_sgmv.add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if SYSTEM == "ipex":
bgmv_shrink(
input,
lora_a_ptr,
v,
rank_segments.indices,
1.0,
)
else:
# TODO: error with [-1, 0], but not [0, -1]
punica_sgmv.add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
punica_sgmv.add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
elif SYSTEM == "ipex" and data is not None:
from intel_extension_for_pytorch.llm.functional import (
bgmv_expand,
bgmv_shrink,
sgmv_expand,
sgmv_shrink,
)
# In IPEX, we provide the same API for sgmv and bgmv
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
seq_len_tensor = (
rank_segments.segment_ends - rank_segments.segment_starts
).to(torch.int64)
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
total_tokens = seq_len_tensor.sum()
v = torch.zeros(
(total_tokens, r), dtype=input.dtype, device=input.device
)
bs = seq_len_tensor.shape[0]
sgmv_shrink(
input,
lora_a_ptr,
v,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
1.0,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
sgmv_expand(
v,
lora_b_ptr,
proj,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
add_inputs=True,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
bgmv_shrink(
input,
lora_a_ptr,
v,
rank_segments.indices,
1.0,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
bgmv_expand(
v,
lora_b_ptr,
proj,
rank_segments.indices,
add_inputs=True,
)
if SYSTEM == "ipex":
bgmv_expand(
v,
lora_b_ptr,
proj,
rank_segments.indices,
add_inputs=True,
)
else:
punica_sgmv.add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj