mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: refactors and helpful comments
This commit is contained in:
parent
a07b612989
commit
3c9b28eaec
@ -4,6 +4,7 @@ include Makefile-vllm
|
||||
include Makefile-awq
|
||||
include Makefile-eetq
|
||||
include Makefile-selective-scan
|
||||
include Makefile-lorax-punica
|
||||
|
||||
unit-tests:
|
||||
pytest -s -vv -m "not private" tests
|
||||
|
@ -85,6 +85,11 @@ def serve(
|
||||
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
|
||||
)
|
||||
|
||||
if len(lora_adapter_ids) > 0:
|
||||
logger.warning(
|
||||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
|
||||
)
|
||||
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
|
@ -1,12 +1,13 @@
|
||||
import math
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from accelerate import init_empty_weights
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.utils.sgmv import (
|
||||
add_lora_a_bgmv,
|
||||
@ -17,16 +18,15 @@ from text_generation_server.utils.sgmv import (
|
||||
orient_for_rank,
|
||||
)
|
||||
|
||||
LORA = "lora"
|
||||
MEDUSA = "medusa"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
def __init__(self, base_layer, layer_id, process_group):
|
||||
def __init__(
|
||||
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
|
||||
):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.layer_id = layer_id
|
||||
@ -41,14 +41,24 @@ class LoraLinear(nn.Module):
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
) -> torch.Tensor:
|
||||
if adapter_data is None:
|
||||
return result
|
||||
data = adapter_data.data.get(layer_type)
|
||||
data: Optional["BatchLoraWeights"] = (
|
||||
data.get(LORA) if data is not None else None
|
||||
data.get("lora") if data is not None else None
|
||||
)
|
||||
|
||||
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||
# This approach ensures accurate LoRA application across various layer sizes and
|
||||
# configurations, adapting to different model architectures and parallelization strategies.
|
||||
#
|
||||
# Example scenarios where this is necessary:
|
||||
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||
# 2. We're processing the last segment which might be smaller.
|
||||
# 3. Different projection layers (q, k, v) have different sizes.
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
@ -58,55 +68,57 @@ class LoraLinear(nn.Module):
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
if lora_a_ptr is not None and lora_b_ptr is not None:
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
if lora_a_ptr is not None and lora_b_ptr is not None:
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1] (presumably the values of rank_segments.indices passed to the call below -- confirm)
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
@ -149,13 +161,27 @@ class LoraLinear(nn.Module):
|
||||
|
||||
|
||||
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_names = layer_names
|
||||
self.sizes = sizes
|
||||
|
||||
@classmethod
|
||||
def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
|
||||
def load(
|
||||
cls,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
return TensorParallelMultiAdapterLinear(
|
||||
base_layer, layer_id, layer_names, sizes, process_group
|
||||
)
|
||||
@ -165,6 +191,10 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
# noop if no layer names are provided (e.g. for models without adapters)
|
||||
if self.layer_names is None:
|
||||
return result
|
||||
|
||||
# handle models like Bloom that have inputs of shape
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||
@ -178,7 +208,12 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
offset = 0
|
||||
for i, layer_name in enumerate(self.layer_names):
|
||||
start_idx = offset // self.process_group.size()
|
||||
|
||||
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||
# different projection sizes across layers and model architectures.
|
||||
if self.sizes is not None:
|
||||
offset += self.sizes[i]
|
||||
end_idx = offset // self.process_group.size()
|
||||
|
@ -56,28 +56,37 @@ if SYSTEM == "rocm":
|
||||
def load_attention(config, prefix, weights, layer_id):
|
||||
# `attention_bias` is only defined in granite configs; fall back to False otherwise.
|
||||
bias = getattr(config, "attention_bias", False)
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = None
|
||||
prefixes = None
|
||||
|
||||
# if specific model type, load the correct attention
|
||||
if config.model_type == "phi3":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
prefix = f"{prefix}.qkv_proj"
|
||||
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
)
|
||||
elif config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
prefix = f"{prefix}.W_pack"
|
||||
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
)
|
||||
else:
|
||||
# otherwise, load the default attention based on the number of heads
|
||||
prefixes = ["q_proj", "k_proj", "v_proj"]
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
@ -86,18 +95,13 @@ def load_attention(config, prefix, weights, layer_id):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer,
|
||||
layer_id,
|
||||
["q_proj", "k_proj", "v_proj"],
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer=base_layer,
|
||||
layer_id=layer_id,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
@ -237,34 +241,39 @@ class LlamaMLP(nn.Module):
|
||||
),
|
||||
)
|
||||
)
|
||||
prefixes = None
|
||||
sizes = None
|
||||
|
||||
# Fuse gate and up proj
|
||||
bias = getattr(config, "mlp_bias", False)
|
||||
if config.model_type == "phi3":
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||
sizes = [
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
]
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
prefixes=prefixes,
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
["gate_proj", "up_proj"],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
@ -273,15 +282,12 @@ class LlamaMLP(nn.Module):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if config.model_type == "phi3":
|
||||
self.down_proj = down_proj
|
||||
else:
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import os
|
||||
from loguru import logger
|
||||
from typing import Dict
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
@ -30,9 +31,9 @@ def set_model_id(model_id: str):
|
||||
# NOTE: eventually we should move this into the router and pass back the
|
||||
# index in all cases.
|
||||
global ADAPTER_TO_INDEX
|
||||
ADAPTER_TO_INDEX = None
|
||||
ADAPTER_TO_INDEX: Dict[str, int] = None
|
||||
|
||||
|
||||
def set_adapter_to_index(adapter_to_index: dict):
|
||||
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
|
||||
global ADAPTER_TO_INDEX
|
||||
ADAPTER_TO_INDEX = adapter_to_index
|
||||
|
Loading…
Reference in New Issue
Block a user