Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-06-19 15:52:08 +00:00
lora enable in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent e32528792c
commit 3338b34ba4
server/text_generation_server/adapters/lora.py

@@ -11,9 +11,8 @@ import torch
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
-from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
@@ -128,17 +127,27 @@ class LoraWeights(AdapterWeights):
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
 
-        self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
+        if SYSTEM == "ipex":
+            self._use_cutlass_shrink = False
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
 
-        # [num_layers, hidden_size, r]
-        weights_a = [
-            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
-        ]
-        self._weights_a = torch.stack(weights_a)
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
+            # [num_layers, hidden_size, r]
+            weights_a = [
+                punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
+                for w in weights_a
+            ]
+            self._weights_a = torch.stack(weights_a)
 
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)
 
         self.adapter_config = adapter_config
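Note: the hunk above stores the IPEX copies of A and B pre-transposed instead of routing them through punica's orient_for_rank. A minimal sketch of the two layouts, with made-up sizes (not TGI code):

import torch

num_layers, hidden_size, r = 4, 64, 8
# per-layer LoRA A matrices arrive as [hidden_size, r]
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]

# IPEX path: transpose each layer before stacking -> [num_layers, r, hidden_size]
ipex_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
assert ipex_a.shape == (num_layers, r, hidden_size)

# CUDA path: stack as-is -> [num_layers, hidden_size, r]; punica's
# orient_for_rank may additionally reorder dims for small ranks (omitted here).
cuda_a = torch.stack([w.contiguous() for w in weights_a])
assert cuda_a.shape == (num_layers, hidden_size, r)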
@@ -175,7 +184,10 @@ class LoraWeights(AdapterWeights):
 
     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]
 
     # prepare pre-loaded lora weights for use in the model.
     #
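Note: get_batch_types now selects the batch implementation by platform. A self-contained sketch of that dispatch pattern (simplified stand-ins, not TGI code; SYSTEM mimics text_generation_server.utils.import_utils.SYSTEM):

from typing import List, Type

SYSTEM = "ipex"  # assumption: detected once at import time


class BatchWeights:
    pass


class IPEXBatchWeights(BatchWeights):
    pass


class Weights:
    @classmethod
    def get_batch_types(cls) -> List[Type[BatchWeights]]:
        # route to the platform-specific batch container
        if SYSTEM == "ipex":
            return [IPEXBatchWeights]
        else:
            return [BatchWeights]


assert Weights.get_batch_types() == [IPEXBatchWeights]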
@@ -245,17 +257,20 @@ class LoraWeights(AdapterWeights):
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
-        ]
+        if SYSTEM != "ipex":
+            lora_a_list = [
+                punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
+                for w in lora_a_list
+            ]
+            lora_b_list = [
+                punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
+                for w in lora_b_list
+            ]
 
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank
 
         return LoraWeights(
             *shard_lora_weights(
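Note: punica's SGMV kernels only support certain ranks, which is why the CUDA path pads every adapter's rank before stacking; the IPEX kernels consume the rank unchanged, so the padding (and the config.r update) is skipped when SYSTEM == "ipex". A hypothetical pad-to-next-power-of-two for illustration only, not punica_sgmv.pad_rank itself:

import torch
import torch.nn.functional as F


def pad_rank_sketch(w: torch.Tensor, dim: int) -> torch.Tensor:
    # pad the rank dimension up to the next power of two
    r = w.size(dim)
    target = 1 << (r - 1).bit_length()
    if target == r:
        return w
    pad = [0, 0] * w.dim()  # F.pad reads (left, right) pairs from the last dim backwards
    pad[2 * (w.dim() - 1 - dim) + 1] = target - r
    return F.pad(w, pad)


lora_a = torch.randn(64, 5)  # one layer: [hidden_size, r=5]
assert pad_rank_sketch(lora_a, dim=1).shape == (64, 8)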
@@ -471,6 +486,115 @@ class BatchLoraWeights(BatchAdapterWeights):
         )
 
 
+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
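Note: IPEXBatchLoraWeights leans on AdapterBatchMetadata's cumulative segment boundaries, where adapter_segments marks where each run of same-adapter tokens starts and ends. A toy walk-through of that bookkeeping (assumed values, not TGI code):

import torch

# 7 prefill tokens: the first 3 use one adapter, the next 4 another.
adapter_segments = torch.tensor([0, 3, 7])  # cumulative boundaries
segment_starts = adapter_segments[:-1]      # tensor([0, 3])
segment_ends = adapter_segments[1:]         # tensor([3, 7])
seq_len_tensor = segment_ends - segment_starts
assert seq_len_tensor.tolist() == [3, 4]    # tokens per segment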
server/text_generation_server/layers/lora.py

@@ -4,8 +4,8 @@ import torch
 import torch.distributed
 from torch import nn
 from torch.distributed import ProcessGroup
-
+from text_generation_server.utils.import_utils import SYSTEM
 
 from text_generation_server.utils.kernels import load_kernel
 
 if SYSTEM == "cuda":
@@ -121,6 +121,91 @@ class LoraLinear(nn.Module):
                     self.layer_id,
                 )
 
                 if end_idx - start_idx != result.shape[1]:
                     result[:, start_idx:end_idx] += proj
+        elif SYSTEM == "ipex" and data is not None:
+            from intel_extension_for_pytorch.llm.functional import (
+                bgmv_expand,
+                bgmv_shrink,
+                sgmv_expand,
+                sgmv_shrink,
+            )
+
+            # In IPEX, we provide the same API for sgmv and bgmv
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
+
+            for r, rank_segments in data.rank_data.items():
+                lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
+                lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
+
+                if lora_a_ptr is None or lora_b_ptr is None:
+                    raise ValueError("LoRA data is missing")
+
+                if data.use_sgmv:
+                    # Use SGMV for prefill
+                    seq_len_tensor = (
+                        rank_segments.segment_ends - rank_segments.segment_starts
+                    ).to(torch.int64)
+                    b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
+                    total_tokens = seq_len_tensor.sum()
+                    v = torch.zeros(
+                        (total_tokens, r), dtype=input.dtype, device=input.device
+                    )
+                    bs = seq_len_tensor.shape[0]
+                    sgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    sgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        add_inputs=True,
+                    )
+                else:
+                    # Use BGMV for decode
+                    v = torch.zeros(
+                        (input.size(0), r), dtype=input.dtype, device=input.device
+                    )
+                    # TODO: error with [-1, 0], but not [0, -1]
+                    bgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        rank_segments.indices,
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    bgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        rank_segments.indices,
+                        add_inputs=True,
+                    )
+
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
         else:
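Note: each shrink/expand pair is mathematically a two-step low-rank matmul; the IPEX kernels just perform it segment-wise (SGMV, prefill) or token-wise (BGMV, decode) across adapters. A dense reference sketch of what one pair computes for a single adapter, for intuition rather than the kernels themselves:

import torch

tokens, hidden_size, r = 7, 64, 8
x = torch.randn(tokens, hidden_size)
lora_a = torch.randn(hidden_size, r)  # shrink weight of one adapter
lora_b = torch.randn(r, hidden_size)  # expand weight of one adapter
result = torch.zeros(tokens, hidden_size)

v = x @ lora_a        # shrink: project activations down to rank r
result += v @ lora_b  # expand back up, matching add_inputs=True above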