Enable LoRA in XPU

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A (2025-05-21 18:24:04 -07:00)
parent e32528792c | commit 3338b34ba4
2 changed files with 231 additions and 22 deletions
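Taken together, the two files make the LoRA adapter path usable on Intel XPU. When SYSTEM == "ipex" the server no longer relies on the punica SGMV CUDA kernels: adapter A/B weights are stored pre-transposed for the IPEX ops, the SGMV rank padding is skipped, adapter metadata is batched through a new IPEXBatchLoraWeights class, and LoraLinear.forward calls sgmv_shrink/sgmv_expand during prefill and bgmv_shrink/bgmv_expand during decode from intel_extension_for_pytorch.llm.functional.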


@@ -11,9 +11,8 @@ import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
@@ -128,17 +127,27 @@ class LoraWeights(AdapterWeights):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
if SYSTEM == "ipex":
self._use_cutlass_shrink = False
# [num_layers, r, hidden_size]
weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, hidden_size, r]
weights_a = [
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
]
self._weights_a = torch.stack(weights_a)
# [num_layers, hidden_size, r]
weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
self._weights_b = torch.stack(weights_b)
else:
self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
# [num_layers, hidden_size, r]
weights_a = [
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
for w in weights_a
]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
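The hunk above splits the weight layout per backend: under SYSTEM == "ipex" the cutlass shrink path is disabled and both A and B are stored as plain transposes ([num_layers, r, hidden_size] and [num_layers, hidden_size, r]), while the CUDA/ROCm branch keeps the punica orientation. The short sketch below is not part of the diff (orient_for_rank is replaced by an identity stand-in); it only illustrates the resulting tensor shapes:

import torch

num_layers, hidden_size, r = 2, 8, 4

# Per-layer weights as loaded: A is [hidden_size, r], B is [r, hidden_size].
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]
weights_b = [torch.randn(r, hidden_size) for _ in range(num_layers)]

# IPEX layout: plain transposes, no punica orientation and no cutlass shrink.
ipex_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
ipex_b = torch.stack([w.transpose(0, 1).contiguous() for w in weights_b])
print(ipex_a.shape)  # torch.Size([2, 4, 8]) -> [num_layers, r, hidden_size]
print(ipex_b.shape)  # torch.Size([2, 8, 4]) -> [num_layers, hidden_size, r]

# CUDA/ROCm layout: A keeps the punica orientation (orient_for_rank, identity here,
# may transpose for some ranks), B is stacked as-is.
cuda_a = torch.stack([w.contiguous() for w in weights_a])
cuda_b = torch.stack(weights_b)
print(cuda_a.shape)  # torch.Size([2, 8, 4]) -> [num_layers, hidden_size, r]
print(cuda_b.shape)  # torch.Size([2, 4, 8]) -> [num_layers, r, hidden_size]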
@@ -175,7 +184,10 @@ class LoraWeights(AdapterWeights):
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
if SYSTEM == "ipex":
return [IPEXBatchLoraWeights]
else:
return [BatchLoraWeights]
# prepare pre-loaded lora weights for use in the model.
#
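get_batch_types() is now the per-backend switch for how adapter weights are batched. A tiny illustration of the effect (the module path text_generation_server.adapters.lora is assumed, not shown in the diff):

from text_generation_server.adapters.lora import LoraWeights  # assumed module path

# The single entry returned decides the batching implementation:
# IPEXBatchLoraWeights when SYSTEM == "ipex", BatchLoraWeights otherwise.
batch_cls = LoraWeights.get_batch_types()[0]
print(batch_cls.__name__)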
@@ -245,17 +257,20 @@ class LoraWeights(AdapterWeights):
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [
punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
]
lora_b_list = [
punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
]
if SYSTEM != "ipex":
lora_a_list = [
punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
for w in lora_a_list
]
lora_b_list = [
punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
for w in lora_b_list
]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*shard_lora_weights(
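Rank padding is now skipped on IPEX: the raw adapter rank is kept, whereas the other backends still pad lora A (dim 1) and lora B (dim 0) so the SGMV kernels only see supported rank sizes, and config.r is bumped to the padded value. A minimal sketch of that padding idea, with a hypothetical pad_to_supported_rank helper standing in for punica_sgmv.pad_rank (the real supported sizes and the world_size handling may differ):

import torch
import torch.nn.functional as F

def pad_to_supported_rank(w: torch.Tensor, dim: int, supported=(8, 16, 32, 64, 128)) -> torch.Tensor:
    # Hypothetical stand-in: zero-pad `dim` up to the next assumed-supported rank.
    r = w.size(dim)
    target = next((s for s in supported if s >= r), r)
    if target == r:
        return w
    pad = [0, 0] * w.dim()
    # F.pad counts dimensions from the last one; pad the upper side of `dim`.
    pad[2 * (w.dim() - 1 - dim) + 1] = target - r
    return F.pad(w, pad)

lora_a = torch.randn(768, 10)            # [hidden_size, r]
padded = pad_to_supported_rank(lora_a, dim=1)
print(padded.shape)                      # torch.Size([768, 16]) under these assumptions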
@@ -471,6 +486,115 @@ class BatchLoraWeights(BatchAdapterWeights):
)
@dataclass
class IPEXBatchLoraWeights(BatchLoraWeights):
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
if len(lora_a) != 0:
lora_a_ptr = torch.stack(list(lora_a.values()))
if len(lora_b) != 0:
lora_b_ptr = torch.stack(list(lora_b.values()))
use_sgmv = True if prefill else False
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
segment_starts = None
segment_ends = None
if use_sgmv:
segment_starts = meta.adapter_segments[:-1]
segment_ends = meta.adapter_segments[1:]
if prefill_head_indices is not None:
segment_starts = prefill_head_segment_starts[:-1]
segment_ends = prefill_head_segment_ends[1:]
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
for rank, indices in rank_indices.items():
adapters_indices = []
lora_a_keys = list(lora_a.keys())
for segment_idx in batch_indices:
if segment_idx in indices:
adapters_indices.append(
lora_a_keys.index(segment_indices[segment_idx])
)
else:
adapters_indices.append(-1)
adapters_indices = torch.tensor(
adapters_indices, dtype=torch.int64, device=device
)
if use_sgmv:
adapters_indices = adapters_indices[segment_starts]
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=None,
tmp_expand=None,
lora_a_ptr=lora_a_ptr,
lora_b_ptr=lora_b_ptr,
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=adapters_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
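The core of IPEXBatchLoraWeights.load is the per-rank index construction: adapters are grouped by rank, and for each rank a RankSegments entry records, per token (decode) or per segment (prefill, after adapters_indices[segment_starts]), which row of the stacked lora_a_ptr/lora_b_ptr holds that token's adapter, with -1 meaning no adapter of this rank. The toy walkthrough below mirrors that computation with made-up adapter ids and ranks; it is illustrative only:

from collections import defaultdict
import torch

# Toy batch: two segments, adapters 3 and 7; adapter 3 has rank 8, adapter 7 rank 16.
segment_indices = [3, 7]              # adapter id owning each contiguous segment
adapter_rank = {3: 8, 7: 16}          # stands in for LoraWeights.lora_a_r
token_adapter_ids = [3, 3, 7, 7, 7]   # meta.adapter_indices for a 5-token batch
lora_a_keys = [3, 7]                  # row order of the stacked lora_a_ptr / lora_b_ptr

adapter_to_segment = {a: s for s, a in enumerate(segment_indices)}
batch_indices = [adapter_to_segment[a] for a in token_adapter_ids]  # [0, 0, 1, 1, 1]

# Group segments by rank, exactly like rank_indices in load().
rank_indices = defaultdict(list)
for seg, a in enumerate(segment_indices):
    rank_indices[adapter_rank[a]].append(seg)

# Per-rank token -> adapter-row mapping (-1 where the token's adapter has another rank).
for rank, segs in rank_indices.items():
    indices = torch.tensor(
        [lora_a_keys.index(segment_indices[s]) if s in segs else -1 for s in batch_indices]
    )
    print(rank, indices.tolist())
# 8  -> [0, 0, -1, -1, -1]
# 16 -> [-1, -1, 1, 1, 1]
# During prefill, the same per-token indices are reduced to one entry per segment
# via adapters_indices[segment_starts] before being handed to the SGMV ops.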
def get_scaling_factor(
lora_alpha: int,
r: int,

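For reference, the scale applied to the B weights above (lora_b.transpose(0, 1) * scale) comes from get_scaling_factor, whose body is cut off by this hunk. A reminder of the usual LoRA convention, written as a small stand-alone sketch (the rsLoRA variant is assumed, not shown here):

def lora_scaling(lora_alpha: int, r: int, uses_rslora: bool = False) -> float:
    # Standard LoRA scaling: alpha / r; rsLoRA (assumed variant) uses alpha / sqrt(r).
    return lora_alpha / (r**0.5) if uses_rslora else lora_alpha / r

print(lora_scaling(16, 8))        # 2.0
print(lora_scaling(16, 8, True))  # ~5.657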

@@ -4,8 +4,8 @@ import torch
import torch.distributed
from torch import nn
from torch.distributed import ProcessGroup
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
if SYSTEM == "cuda":
@@ -121,6 +121,91 @@ class LoraLinear(nn.Module):
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
elif SYSTEM == "ipex" and data is not None:
from intel_extension_for_pytorch.llm.functional import (
bgmv_expand,
bgmv_shrink,
sgmv_expand,
sgmv_shrink,
)
# In IPEX, we provide the same API for sgmv and bgmv
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
seq_len_tensor = (
rank_segments.segment_ends - rank_segments.segment_starts
).to(torch.int64)
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
total_tokens = seq_len_tensor.sum()
v = torch.zeros(
(total_tokens, r), dtype=input.dtype, device=input.device
)
bs = seq_len_tensor.shape[0]
sgmv_shrink(
input,
lora_a_ptr,
v,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
1.0,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
sgmv_expand(
v,
lora_b_ptr,
proj,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
add_inputs=True,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
bgmv_shrink(
input,
lora_a_ptr,
v,
rank_segments.indices,
1.0,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
bgmv_expand(
v,
lora_b_ptr,
proj,
rank_segments.indices,
add_inputs=True,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
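On the IPEX path, LoraLinear.forward runs the usual two-step LoRA product per rank group: a shrink (x @ A^T into a rank-r buffer v, with the tensor-parallel all-gather via collect_lora_a in between) followed by an expand (v @ B^T accumulated into proj, since add_inputs=True). During prefill it uses the segment-based sgmv_* ops driven by segment_starts and per-segment lengths; during decode it uses the per-token bgmv_* ops driven by the per-token adapter indices. The plain-PyTorch sketch below emulates what those kernels compute for one layer, using made-up shapes and the same -1 convention for "no adapter"; it is an illustration, not the IPEX kernels:

import torch

hidden, out_dim, r, tokens = 8, 8, 4, 5

x = torch.randn(tokens, hidden)          # input activations
result = torch.zeros(tokens, out_dim)    # base projection output (zeros for the demo)

# Stacked adapter weights for this layer, one row per loaded adapter,
# following the IPEX storage above:
lora_a = torch.randn(2, r, hidden)       # lora_a_ptr[:, layer_id, :] -> [num_adapters, r, hidden]
lora_b = torch.randn(2, out_dim, r)      # lora_b_ptr[:, layer_id, :] -> [num_adapters, out_features, r]

indices = torch.tensor([0, 0, 1, 1, -1]) # adapter row per token; -1 = no adapter

# "bgmv_shrink"-like step: v[t] = x[t] @ A[idx].T
v = torch.zeros(tokens, r)
for t, idx in enumerate(indices.tolist()):
    if idx >= 0:
        v[t] = x[t] @ lora_a[idx].transpose(0, 1)

# (with tensor parallelism, v would be all-gathered here, as collect_lora_a does)

# "bgmv_expand"-like step with add_inputs=True: result[t] += v[t] @ B[idx].T
for t, idx in enumerate(indices.tolist()):
    if idx >= 0:
        result[t] += v[t] @ lora_b[idx].transpose(0, 1)

print(result.shape)  # torch.Size([5, 8])
# The prefill sgmv_* ops do the same math, but batched over contiguous segments
# described by b_seq_start_loc and seq_len_tensor, with one index entry per segment.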