mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
refine code
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
3338b34ba4
commit
7c8694545f
@ -15,6 +15,14 @@ if SYSTEM == "cuda":
|
||||
else:
|
||||
punica_sgmv = None
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
from intel_extension_for_pytorch.llm.functional import (
|
||||
bgmv_expand,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_shrink,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
@ -44,9 +52,9 @@ class LoraLinear(nn.Module):
|
||||
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
|
||||
|
||||
if (
|
||||
punica_sgmv is not None
|
||||
and data is not None
|
||||
and data.can_vectorize(self.process_group)
|
||||
data is not None
|
||||
and SYSTEM == "ipex"
|
||||
or (punica_sgmv is not None and data.can_vectorize(self.process_group))
|
||||
):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
@ -66,145 +74,121 @@ class LoraLinear(nn.Module):
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
if SYSTEM == "ipex":
|
||||
lora_a_ptr = rank_segments.lora_a_ptr[
|
||||
:, self.layer_id, :
|
||||
].contiguous()
|
||||
lora_b_ptr = rank_segments.lora_b_ptr[
|
||||
:, self.layer_id, :
|
||||
].contiguous()
|
||||
else:
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
v = punica_sgmv.lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
if SYSTEM == "ipex":
|
||||
# Use SGMV for prefill
|
||||
seq_len_tensor = (
|
||||
rank_segments.segment_ends - rank_segments.segment_starts
|
||||
).to(torch.int64)
|
||||
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
v = torch.zeros(
|
||||
(total_tokens, r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
bs = seq_len_tensor.shape[0]
|
||||
sgmv_shrink(
|
||||
input,
|
||||
lora_a_ptr,
|
||||
v,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
rank_segments.indices,
|
||||
bs,
|
||||
seq_len_tensor.max().item(),
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# Use SGMV for prefill
|
||||
v = punica_sgmv.lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
punica_sgmv.lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
if SYSTEM == "ipex":
|
||||
sgmv_expand(
|
||||
v,
|
||||
lora_b_ptr,
|
||||
proj,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
rank_segments.indices,
|
||||
bs,
|
||||
seq_len_tensor.max().item(),
|
||||
add_inputs=True,
|
||||
)
|
||||
else:
|
||||
punica_sgmv.lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
punica_sgmv.add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
if SYSTEM == "ipex":
|
||||
bgmv_shrink(
|
||||
input,
|
||||
lora_a_ptr,
|
||||
v,
|
||||
rank_segments.indices,
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
punica_sgmv.add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
punica_sgmv.add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
elif SYSTEM == "ipex" and data is not None:
|
||||
from intel_extension_for_pytorch.llm.functional import (
|
||||
bgmv_expand,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_shrink,
|
||||
)
|
||||
|
||||
# In IPEX, we provide the same API for sgmv and bgmv
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
|
||||
lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
seq_len_tensor = (
|
||||
rank_segments.segment_ends - rank_segments.segment_starts
|
||||
).to(torch.int64)
|
||||
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
v = torch.zeros(
|
||||
(total_tokens, r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
bs = seq_len_tensor.shape[0]
|
||||
sgmv_shrink(
|
||||
input,
|
||||
lora_a_ptr,
|
||||
v,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
rank_segments.indices,
|
||||
bs,
|
||||
seq_len_tensor.max().item(),
|
||||
1.0,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
sgmv_expand(
|
||||
v,
|
||||
lora_b_ptr,
|
||||
proj,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
rank_segments.indices,
|
||||
bs,
|
||||
seq_len_tensor.max().item(),
|
||||
add_inputs=True,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
bgmv_shrink(
|
||||
input,
|
||||
lora_a_ptr,
|
||||
v,
|
||||
rank_segments.indices,
|
||||
1.0,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
bgmv_expand(
|
||||
v,
|
||||
lora_b_ptr,
|
||||
proj,
|
||||
rank_segments.indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
if SYSTEM == "ipex":
|
||||
bgmv_expand(
|
||||
v,
|
||||
lora_b_ptr,
|
||||
proj,
|
||||
rank_segments.indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
else:
|
||||
punica_sgmv.add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
|
Loading…
Reference in New Issue
Block a user