mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
refine code
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
3338b34ba4
commit
7c8694545f
@ -15,6 +15,14 @@ if SYSTEM == "cuda":
|
|||||||
else:
|
else:
|
||||||
punica_sgmv = None
|
punica_sgmv = None
|
||||||
|
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
from intel_extension_for_pytorch.llm.functional import (
|
||||||
|
bgmv_expand,
|
||||||
|
bgmv_shrink,
|
||||||
|
sgmv_expand,
|
||||||
|
sgmv_shrink,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from text_generation_server.adapters import AdapterBatchData
|
from text_generation_server.adapters import AdapterBatchData
|
||||||
@ -44,9 +52,9 @@ class LoraLinear(nn.Module):
|
|||||||
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
|
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
punica_sgmv is not None
|
data is not None
|
||||||
and data is not None
|
and SYSTEM == "ipex"
|
||||||
and data.can_vectorize(self.process_group)
|
or (punica_sgmv is not None and data.can_vectorize(self.process_group))
|
||||||
):
|
):
|
||||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||||
# The 'result' tensor represents the full output, which can vary in size based on
|
# The 'result' tensor represents the full output, which can vary in size based on
|
||||||
@ -66,145 +74,121 @@ class LoraLinear(nn.Module):
|
|||||||
proj = result
|
proj = result
|
||||||
|
|
||||||
for r, rank_segments in data.rank_data.items():
|
for r, rank_segments in data.rank_data.items():
|
||||||
lora_a_ptr = rank_segments.lora_a_ptr
|
if SYSTEM == "ipex":
|
||||||
lora_b_ptr = rank_segments.lora_b_ptr
|
lora_a_ptr = rank_segments.lora_a_ptr[
|
||||||
|
:, self.layer_id, :
|
||||||
|
].contiguous()
|
||||||
|
lora_b_ptr = rank_segments.lora_b_ptr[
|
||||||
|
:, self.layer_id, :
|
||||||
|
].contiguous()
|
||||||
|
else:
|
||||||
|
lora_a_ptr = rank_segments.lora_a_ptr
|
||||||
|
lora_b_ptr = rank_segments.lora_b_ptr
|
||||||
|
|
||||||
if lora_a_ptr is None or lora_b_ptr is None:
|
if lora_a_ptr is None or lora_b_ptr is None:
|
||||||
raise ValueError("LoRA data is missing")
|
raise ValueError("LoRA data is missing")
|
||||||
|
|
||||||
if data.use_sgmv:
|
if data.use_sgmv:
|
||||||
# Use SGMV for prefill
|
if SYSTEM == "ipex":
|
||||||
v = punica_sgmv.lora_a_sgmv_cutlass(
|
# Use SGMV for prefill
|
||||||
input,
|
seq_len_tensor = (
|
||||||
rank_segments.tmp_shrink,
|
rank_segments.segment_ends - rank_segments.segment_starts
|
||||||
lora_a_ptr,
|
).to(torch.int64)
|
||||||
rank_segments.segment_starts,
|
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
|
||||||
rank_segments.segment_ends,
|
total_tokens = seq_len_tensor.sum()
|
||||||
self.layer_id,
|
v = torch.zeros(
|
||||||
r,
|
(total_tokens, r), dtype=input.dtype, device=input.device
|
||||||
)
|
)
|
||||||
|
bs = seq_len_tensor.shape[0]
|
||||||
|
sgmv_shrink(
|
||||||
|
input,
|
||||||
|
lora_a_ptr,
|
||||||
|
v,
|
||||||
|
b_seq_start_loc,
|
||||||
|
seq_len_tensor,
|
||||||
|
rank_segments.indices,
|
||||||
|
bs,
|
||||||
|
seq_len_tensor.max().item(),
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use SGMV for prefill
|
||||||
|
v = punica_sgmv.lora_a_sgmv_cutlass(
|
||||||
|
input,
|
||||||
|
rank_segments.tmp_shrink,
|
||||||
|
lora_a_ptr,
|
||||||
|
rank_segments.segment_starts,
|
||||||
|
rank_segments.segment_ends,
|
||||||
|
self.layer_id,
|
||||||
|
r,
|
||||||
|
)
|
||||||
|
|
||||||
if self.process_group.size() > 1:
|
if self.process_group.size() > 1:
|
||||||
v = self.collect_lora_a(v)
|
v = self.collect_lora_a(v)
|
||||||
|
if SYSTEM == "ipex":
|
||||||
punica_sgmv.lora_b_sgmv_cutlass(
|
sgmv_expand(
|
||||||
proj,
|
v,
|
||||||
v,
|
lora_b_ptr,
|
||||||
rank_segments.tmp_expand,
|
proj,
|
||||||
lora_b_ptr,
|
b_seq_start_loc,
|
||||||
rank_segments.segment_starts,
|
seq_len_tensor,
|
||||||
rank_segments.segment_ends,
|
rank_segments.indices,
|
||||||
self.layer_id,
|
bs,
|
||||||
)
|
seq_len_tensor.max().item(),
|
||||||
|
add_inputs=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
punica_sgmv.lora_b_sgmv_cutlass(
|
||||||
|
proj,
|
||||||
|
v,
|
||||||
|
rank_segments.tmp_expand,
|
||||||
|
lora_b_ptr,
|
||||||
|
rank_segments.segment_starts,
|
||||||
|
rank_segments.segment_ends,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Use BGMV for decode
|
# Use BGMV for decode
|
||||||
v = torch.zeros(
|
v = torch.zeros(
|
||||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||||
)
|
)
|
||||||
# TODO: error with [-1, 0], but not [0, -1]
|
if SYSTEM == "ipex":
|
||||||
punica_sgmv.add_lora_a_bgmv(
|
bgmv_shrink(
|
||||||
v,
|
input,
|
||||||
input,
|
lora_a_ptr,
|
||||||
lora_a_ptr,
|
v,
|
||||||
rank_segments.indices,
|
rank_segments.indices,
|
||||||
self.layer_id,
|
1.0,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# TODO: error with [-1, 0], but not [0, -1]
|
||||||
|
punica_sgmv.add_lora_a_bgmv(
|
||||||
|
v,
|
||||||
|
input,
|
||||||
|
lora_a_ptr,
|
||||||
|
rank_segments.indices,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
if self.process_group.size() > 1:
|
if self.process_group.size() > 1:
|
||||||
v = self.collect_lora_a(v)
|
v = self.collect_lora_a(v)
|
||||||
|
|
||||||
punica_sgmv.add_lora_b_bgmv(
|
if SYSTEM == "ipex":
|
||||||
proj,
|
bgmv_expand(
|
||||||
v,
|
v,
|
||||||
lora_b_ptr,
|
lora_b_ptr,
|
||||||
rank_segments.indices,
|
proj,
|
||||||
self.layer_id,
|
rank_segments.indices,
|
||||||
)
|
add_inputs=True,
|
||||||
|
)
|
||||||
if end_idx - start_idx != result.shape[1]:
|
else:
|
||||||
result[:, start_idx:end_idx] += proj
|
punica_sgmv.add_lora_b_bgmv(
|
||||||
elif SYSTEM == "ipex" and data is not None:
|
proj,
|
||||||
from intel_extension_for_pytorch.llm.functional import (
|
v,
|
||||||
bgmv_expand,
|
lora_b_ptr,
|
||||||
bgmv_shrink,
|
rank_segments.indices,
|
||||||
sgmv_expand,
|
self.layer_id,
|
||||||
sgmv_shrink,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# In IPEX, we provide the same API for sgmv and bgmv
|
|
||||||
if end_idx - start_idx != result.shape[1]:
|
|
||||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
|
||||||
else:
|
|
||||||
proj = result
|
|
||||||
|
|
||||||
for r, rank_segments in data.rank_data.items():
|
|
||||||
lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
|
|
||||||
lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
|
|
||||||
|
|
||||||
if lora_a_ptr is None or lora_b_ptr is None:
|
|
||||||
raise ValueError("LoRA data is missing")
|
|
||||||
|
|
||||||
if data.use_sgmv:
|
|
||||||
# Use SGMV for prefill
|
|
||||||
seq_len_tensor = (
|
|
||||||
rank_segments.segment_ends - rank_segments.segment_starts
|
|
||||||
).to(torch.int64)
|
|
||||||
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
|
|
||||||
total_tokens = seq_len_tensor.sum()
|
|
||||||
v = torch.zeros(
|
|
||||||
(total_tokens, r), dtype=input.dtype, device=input.device
|
|
||||||
)
|
|
||||||
bs = seq_len_tensor.shape[0]
|
|
||||||
sgmv_shrink(
|
|
||||||
input,
|
|
||||||
lora_a_ptr,
|
|
||||||
v,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
rank_segments.indices,
|
|
||||||
bs,
|
|
||||||
seq_len_tensor.max().item(),
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.process_group.size() > 1:
|
|
||||||
v = self.collect_lora_a(v)
|
|
||||||
|
|
||||||
sgmv_expand(
|
|
||||||
v,
|
|
||||||
lora_b_ptr,
|
|
||||||
proj,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
rank_segments.indices,
|
|
||||||
bs,
|
|
||||||
seq_len_tensor.max().item(),
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Use BGMV for decode
|
|
||||||
v = torch.zeros(
|
|
||||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
|
||||||
)
|
|
||||||
# TODO: error with [-1, 0], but not [0, -1]
|
|
||||||
bgmv_shrink(
|
|
||||||
input,
|
|
||||||
lora_a_ptr,
|
|
||||||
v,
|
|
||||||
rank_segments.indices,
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.process_group.size() > 1:
|
|
||||||
v = self.collect_lora_a(v)
|
|
||||||
|
|
||||||
bgmv_expand(
|
|
||||||
v,
|
|
||||||
lora_b_ptr,
|
|
||||||
proj,
|
|
||||||
rank_segments.indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if end_idx - start_idx != result.shape[1]:
|
if end_idx - start_idx != result.shape[1]:
|
||||||
result[:, start_idx:end_idx] += proj
|
result[:, start_idx:end_idx] += proj
|
||||||
|
Loading…
Reference in New Issue
Block a user