# Mirror of https://github.com/huggingface/text-generation-inference.git
# Synced 2025-04-21 14:52:20 +00:00
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F

from text_generation_server.utils.sgmv import (
    add_lora_a_bgmv,
    add_lora_b_bgmv,
    has_sgmv,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    orient_for_rank,
)

# Adapter-type keys used to index into AdapterBatchData.data.
LORA = "lora"
MEDUSA = "medusa"

if TYPE_CHECKING:
    # Imported only for type checking to avoid a circular import at runtime.
    from text_generation_server.adapters import AdapterBatchData
    from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
    """Wraps a base linear layer and adds per-request LoRA adapter deltas
    to its output.

    Subclasses implement ``collect_lora_a`` to combine the sharded
    intermediate (shrink) activation across tensor-parallel ranks.
    """

    def __init__(self, base_layer, layer_id, process_group):
        super().__init__()
        # The wrapped (non-LoRA) layer whose output gets the LoRA delta.
        self.base_layer = base_layer
        # Layer index used to select this layer's slice of the stacked
        # adapter weight tensors.
        self.layer_id = layer_id
        # torch.distributed process group used for tensor parallelism.
        self.process_group = process_group

    def forward_layer_type(
        self,
        result: torch.Tensor,
        input: torch.Tensor,
        adapter_data: "AdapterBatchData",
        layer_type: str,
        start_idx: int,
        end_idx: int,
    ) -> torch.Tensor:
        """Add the LoRA contribution for ``layer_type`` into
        ``result[:, start_idx:end_idx]``.

        Uses the fused SGMV/BGMV kernels when available and the batch can be
        vectorized; otherwise falls back to dense per-adapter matmuls.
        ``result`` is modified in place and also returned.
        """
        # Fetch the LoRA weights registered for this layer type; either
        # lookup may come back empty.
        data = adapter_data.data.get(layer_type)
        data: Optional["BatchLoraWeights"] = (
            data.get(LORA) if data is not None else None
        )

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # When targeting a sub-slice of `result`, accumulate into a
            # scratch buffer first; otherwise the kernels can write into
            # `result` directly (proj aliases result in that case).
            if end_idx - start_idx != result.shape[1]:
                proj = torch.zeros_like(result[:, start_idx:end_idx])
            else:
                proj = result

            # Adapters are grouped by LoRA rank `r`, one kernel launch per
            # rank group.
            for r, rank_segments in data.rank_data.items():
                lora_a_ptr = rank_segments.lora_a_ptr
                lora_b_ptr = rank_segments.lora_b_ptr

                if data.use_sgmv:
                    # Use SGMV for prefill
                    if lora_a_ptr is not None and lora_b_ptr is not None:
                        # Shrink step: v = x @ A per token segment.
                        v = lora_a_sgmv_cutlass(
                            input,
                            rank_segments.tmp_shrink,
                            lora_a_ptr,
                            rank_segments.segment_starts,
                            rank_segments.segment_ends,
                            self.layer_id,
                            r,
                        )

                        # With tensor parallelism the shrink output is
                        # sharded; combine it before the expand step.
                        if self.process_group.size() > 1:
                            v = self.collect_lora_a(v)

                        # Expand step: proj += v @ B (in place).
                        lora_b_sgmv_cutlass(
                            proj,
                            v,
                            rank_segments.tmp_expand,
                            lora_b_ptr,
                            rank_segments.segment_starts,
                            rank_segments.segment_ends,
                            self.layer_id,
                        )
                else:
                    # Use BGMV for decode
                    if lora_a_ptr is not None and lora_b_ptr is not None:
                        # Scratch buffer for the rank-r shrink activation.
                        v = torch.zeros(
                            (input.size(0), r), dtype=input.dtype, device=input.device
                        )
                        # Shrink step: v += x @ A, indexed per token.
                        add_lora_a_bgmv(
                            v,
                            input,
                            lora_a_ptr,
                            rank_segments.indices,
                            self.layer_id,
                        )

                        # Combine the sharded shrink activation across ranks.
                        if self.process_group.size() > 1:
                            v = self.collect_lora_a(v)

                        # Expand step: proj += v @ B (in place).
                        add_lora_b_bgmv(
                            proj,
                            v,
                            lora_b_ptr,
                            rank_segments.indices,
                            self.layer_id,
                        )

            # If we accumulated into a scratch buffer, fold it back into the
            # target slice of `result` (otherwise proj already was result).
            if end_idx - start_idx != result.shape[1]:
                result[:, start_idx:end_idx] += proj
        else:
            # Fallback: dense matmul per adapter, masked to the rows that
            # belong to that adapter.
            for adapter_index in adapter_data.meta.adapter_set:
                if data is not None and data.has_adapter(adapter_index):
                    # 1.0 for rows assigned to this adapter, 0.0 elsewhere.
                    adapter_mask = (
                        (adapter_data.meta.adapter_indices == adapter_index)
                        .to(input.dtype)
                        .view(-1, 1)
                    )
                    layer_result = self.forward_lora(
                        input, data, adapter_index, adapter_mask
                    )
                    result[:, start_idx:end_idx] += layer_result

        return result

    def forward_lora(
        self,
        input: torch.Tensor,
        data: "BatchLoraWeights",
        adapter_index: int,
        adapter_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Dense LoRA for a single adapter: ``((x @ A) @ B) * adapter_mask``."""
        # Select this layer's A and B matrices for the given adapter.
        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

        # Orient A so the matmul contracts over the correct dimension.
        lora_a = orient_for_rank(lora_a, lora_b.size(0))

        a_out = input @ lora_a
        if self.process_group.size() > 1:
            # Combine the sharded shrink activation across ranks.
            a_out = self.collect_lora_a(a_out)

        # Mask zeroes out rows that belong to other adapters.
        result = (a_out @ lora_b) * adapter_mask
        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        """Combine the sharded ``x @ A`` activation across tensor-parallel
        ranks; provided by subclasses (all-gather or all-reduce)."""
        raise NotImplementedError("Implemented in subclasses")
class TensorParallelMultiAdapterLinear(LoraLinear):
    """LoRA wrapper for a column-parallel base layer that fuses several
    logical projections (e.g. q/k/v) into a single linear layer.

    ``layer_names``/``sizes`` describe the fused sub-projections so each one
    gets its LoRA delta applied to the correct column slice of the output.
    """

    def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_names = layer_names
        self.sizes = sizes

    @classmethod
    def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
        """Construct a multi-adapter linear from its components."""
        return TensorParallelMultiAdapterLinear(
            base_layer, layer_id, layer_names, sizes, process_group
        )

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        # Models such as Bloom produce (batch, seq, hidden) activations;
        # flatten to (batch * seq, hidden) for the LoRA math, then restore
        # the original shape at the end.
        saved_shape = result.shape
        needs_flatten = len(input.shape) >= 3
        if needs_flatten:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        world_size = self.process_group.size()
        offset = 0
        for idx, name in enumerate(self.layer_names):
            # Column offsets are given in full-model units; divide by the
            # world size to get this shard's local slice.
            start = offset // world_size
            if self.sizes is None:
                end = result.shape[1]
            else:
                offset += self.sizes[idx]
                end = offset // world_size

            result = self.forward_layer_type(
                result, input, adapter_data, name, start, end
            )

        if needs_flatten:
            result = result.reshape(saved_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B where A and B are sharded
        # column-wise: each rank's X @ A is only a slice of the rank-r
        # activation, so all-gather the slices and concatenate them before
        # the expand step.
        #
        # TODO(travis): this is not very efficient as we do an all-gather for
        # every adapter; instead we could pre-allocate a (B, a, r) tensor for
        # all adapters with the same rank, compute `a_out` on each, and then
        # slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        world_size = self.process_group.size()
        shards = [torch.empty_like(a_out) for _ in range(world_size)]
        torch.distributed.all_gather(shards, a_out)
        return torch.cat(shards, dim=1)
class TensorParallelAdapterRowLinear(LoraLinear):
    """LoRA wrapper for a row-parallel base layer.

    Each rank applies the LoRA delta only to its own contiguous slice of the
    output features.
    """

    def __init__(self, base_layer, layer_id, layer_name, process_group):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_name = layer_name

    @classmethod
    def load(cls, base_layer, layer_id, layer_name, process_group):
        """Construct a row-parallel adapter linear from its components."""
        return cls(base_layer, layer_id, layer_name, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
        # This rank owns one contiguous stride of the output features, so the
        # LoRA contribution is added in place to that slice only.
        rank = self.process_group.rank()
        stride = result.shape[-1] // self.process_group.size()
        self.forward_layer_type(
            result,
            input,
            adapter_data,
            self.layer_name,
            rank * stride,
            (rank + 1) * stride,
        )

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B where A and B are sharded
        # row-wise: each rank's X @ A is a partial sum of the full rank-r
        # activation, so an in-place all-reduce completes the sum.
        #
        # TODO(travis): this is not very efficient as we do an all-reduce for
        # every adapter; instead we could pre-allocate a (B, a, r) tensor for
        # all adapters with the same rank, compute `a_out` on each, and then
        # slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        torch.distributed.all_reduce(a_out, group=self.process_group)
        return a_out