mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

commit 3c9b28eaec (parent a07b612989)

fix: refactors and helpful comments
@@ -4,6 +4,7 @@ include Makefile-vllm
 include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
+include Makefile-lorax-punica
 
 unit-tests:
 	pytest -s -vv -m "not private" tests
@@ -85,6 +85,11 @@ def serve(
         [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
     )
 
+    if len(lora_adapter_ids) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        )
+
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
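As a quick illustration of the adapter-ID handling above, here is a minimal standalone sketch; the function name parse_lora_adapter_ids is illustrative (not part of the repository) and it assumes loguru is installed:

from typing import List, Optional

from loguru import logger


def parse_lora_adapter_ids(lora_adapter_ids: Optional[str]) -> List[str]:
    # Split the comma-separated CLI value into a clean list of adapter ids.
    ids = [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
    if len(ids) > 0:
        # Same spirit as the warning added above: LoRA support is experimental.
        logger.warning(
            "LoRA adapters are enabled. This is an experimental feature and may not work as expected."
        )
    return ids


# Example: two adapter ids with stray whitespace.
print(parse_lora_adapter_ids("org/adapter-a, org/adapter-b"))
# ['org/adapter-a', 'org/adapter-b']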
@@ -1,12 +1,13 @@
 import math
 import os
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, List
 
 import torch
 import torch.distributed
 from accelerate import init_empty_weights
 from torch import nn
 from torch.nn import functional as F
+from torch.distributed import ProcessGroup
 
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
@@ -17,16 +18,15 @@ from text_generation_server.utils.sgmv import (
     orient_for_rank,
 )
 
-LORA = "lora"
-MEDUSA = "medusa"
-
 if TYPE_CHECKING:
     from text_generation_server.adapters import AdapterBatchData
     from text_generation_server.adapters.lora import BatchLoraWeights
 
 
 class LoraLinear(nn.Module):
-    def __init__(self, base_layer, layer_id, process_group):
+    def __init__(
+        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
+    ):
         super().__init__()
         self.base_layer = base_layer
         self.layer_id = layer_id
@@ -41,14 +41,24 @@ class LoraLinear(nn.Module):
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
-        if adapter_data is None:
-            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
-            data.get(LORA) if data is not None else None
+            data.get("lora") if data is not None else None
         )
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
+            # The 'result' tensor represents the full output, which can vary in size based on
+            # the layer type (e.g., attention vs. feed-forward layers). We define the current
+            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
+            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
+            # This approach ensures accurate LoRA application across various layer sizes and
+            # configurations, adapting to different model architectures and parallelization strategies.
+            #
+            # Example scenarios where this is necessary:
+            # 1. The adapter's size doesn't evenly divide across GPUs.
+            # 2. We're processing the last segment which might be smaller.
+            # 3. Different projection layers (q, k, v) have different sizes.
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
             else:
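To make the comment above concrete, here is a minimal plain-PyTorch sketch of the segment-size check; shapes and values are illustrative, and the kernel call is replaced by a stand-in:

import torch

# Illustrative shapes: this rank's base-layer output concatenates the q, k and
# v slices along dim 1, and the LoRA delta targets only the k segment.
result = torch.randn(4, 12)   # (batch, q_slice + k_slice + v_slice) = (4, 8 + 2 + 2)
start_idx, end_idx = 8, 10    # boundaries of the k segment on this rank

if end_idx - start_idx != result.shape[1]:
    # Segment is narrower than the full output: accumulate the LoRA
    # contribution into a zero buffer of the segment's size first.
    proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
    # Segment covers the whole output: accumulate directly into it.
    proj = result

# Stand-in for the SGMV/BGMV kernels, which add the LoRA delta into `proj`.
proj += 0.01 * torch.randn_like(proj)

if proj is not result:
    # Scatter the segment back into the full output.
    result[:, start_idx:end_idx] += proj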
@@ -58,9 +68,11 @@ class LoraLinear(nn.Module):
                 lora_a_ptr = rank_segments.lora_a_ptr
                 lora_b_ptr = rank_segments.lora_b_ptr
 
+                if lora_a_ptr is None or lora_b_ptr is None:
+                    raise ValueError("LoRA data is missing")
+
                 if data.use_sgmv:
                     # Use SGMV for prefill
-                    if lora_a_ptr is not None and lora_b_ptr is not None:
                     v = lora_a_sgmv_cutlass(
                         input,
                         rank_segments.tmp_shrink,
@@ -85,10 +97,10 @@ class LoraLinear(nn.Module):
                     )
                 else:
                     # Use BGMV for decode
-                    if lora_a_ptr is not None and lora_b_ptr is not None:
                     v = torch.zeros(
                         (input.size(0), r), dtype=input.dtype, device=input.device
                     )
+                    # TODO: error with [-1, 0], but not [0, -1]
                     add_lora_a_bgmv(
                         v,
                         input,
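For orientation, the SGMV (prefill) and BGMV (decode) kernels referenced above both accumulate the standard LoRA update; a plain-PyTorch equivalent of that math, with made-up shapes and no TGI-specific calls, looks like this:

import torch

batch, hidden, out_features, r = 4, 64, 64, 8
x = torch.randn(batch, hidden)                 # layer input
lora_a = torch.randn(hidden, r) * 0.01         # LoRA "shrink" weight
lora_b = torch.randn(r, out_features) * 0.01   # LoRA "expand" weight
scale = 1.0                                    # typically lora_alpha / r

result = torch.randn(batch, out_features)      # base layer output

# Shrink then expand: the delta the SGMV/BGMV kernels add into the output buffer.
v = x @ lora_a                                 # (batch, r)
result += scale * (v @ lora_b)                 # (batch, out_features)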
@@ -149,13 +161,27 @@ class LoraLinear(nn.Module):
 
 
 class TensorParallelMultiAdapterLinear(LoraLinear):
-    def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         super().__init__(base_layer, layer_id, process_group)
         self.layer_names = layer_names
         self.sizes = sizes
 
     @classmethod
-    def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
+    def load(
+        cls,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         return TensorParallelMultiAdapterLinear(
             base_layer, layer_id, layer_names, sizes, process_group
         )
@@ -165,6 +191,10 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
     ) -> torch.Tensor:
         result = self.base_layer(input)
 
+        # noop if no layer names are provided (e.g. for models without adapters)
+        if self.layer_names is None:
+            return result
+
         # handle models like Bloom that have inputs of shape
         # (batch_size, sequence_length, hidden_size)
         # we need to reshape them to (batch_size * sequence_length, hidden_size)
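A small standalone sketch of the Bloom-style reshape described in these comments (shapes are illustrative):

import torch

batch_size, sequence_length, hidden_size = 2, 5, 16
x = torch.randn(batch_size, sequence_length, hidden_size)

prev_shape = x.shape
is_3d = len(x.shape) >= 3
if is_3d:
    # Flatten (batch, seq, hidden) -> (batch * seq, hidden) so the LoRA
    # matmuls see a single token dimension.
    x = x.reshape(-1, x.shape[-1])
print(x.shape)  # torch.Size([10, 16])

# ... base layer / LoRA computation on the flattened view ...

if is_3d:
    # Restore the original leading dimensions afterwards.
    x = x.reshape(*prev_shape[:-1], x.shape[-1])
print(x.shape)  # torch.Size([2, 5, 16])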
@@ -178,7 +208,12 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
         offset = 0
         for i, layer_name in enumerate(self.layer_names):
             start_idx = offset // self.process_group.size()
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
             if self.sizes is not None:
                 offset += self.sizes[i]
                 end_idx = offset // self.process_group.size()
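A standalone sketch of how the per-rank slice boundaries fall out of sizes and the tensor-parallel world size (the numbers are illustrative, not taken from the repository):

# Full (unsharded) output widths of the fused q/k/v projection, e.g. for a
# model with grouped-query attention: q is wider than k and v.
sizes = [4096, 1024, 1024]
world_size = 2  # number of tensor-parallel ranks

# Each column-parallel layer keeps 1/world_size of every projection, so the
# per-rank segment for projection i spans [offset_i, offset_{i+1}) // world_size.
offset = 0
for i, name in enumerate(["q_proj", "k_proj", "v_proj"]):
    start_idx = offset // world_size
    offset += sizes[i]
    end_idx = offset // world_size
    print(name, start_idx, end_idx)
# q_proj 0 2048
# k_proj 2048 2560
# v_proj 2560 3072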
@@ -56,28 +56,37 @@ if SYSTEM == "rocm":
 def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = None
+    prefixes = None
 
-    # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.qkv_proj"
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=prefix,
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.W_pack"
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=f"{prefix}.W_pack",
+            prefix=prefix,
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
     else:
-        # otherwise, load the default attention based on the number of heads
+        prefixes = ["q_proj", "k_proj", "v_proj"]
+        sizes = [
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ]
         base_layer = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -86,16 +95,11 @@ def load_attention(config, prefix, weights, layer_id):
             bias=bias,
         )
 
-    head_size = config.hidden_size // config.num_attention_heads
     return TensorParallelMultiAdapterLinear.load(
-        base_layer,
-        layer_id,
-        ["q_proj", "k_proj", "v_proj"],
-        sizes=[
-            head_size * config.num_attention_heads,
-            head_size * config.num_key_value_heads,
-            head_size * config.num_key_value_heads,
-        ],
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
         process_group=weights.process_group,
     )
 
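As a worked example of the sizes computed above under grouped-query attention (the config values below are made up for the example):

# Hypothetical config values for a Llama-style model with GQA.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_size = hidden_size // num_attention_heads  # 128

# Output widths of the q/k/v projections that load_attention passes as `sizes`.
sizes = [
    head_size * num_attention_heads,   # q_proj: 4096
    head_size * num_key_value_heads,   # k_proj: 1024
    head_size * num_key_value_heads,   # v_proj: 1024
]
print(sizes)  # [4096, 1024, 1024]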
@@ -237,19 +241,27 @@ class LlamaMLP(nn.Module):
                 ),
             )
         )
+        prefixes = None
+        sizes = None
+
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
                 bias=bias,
             )
         else:
+            prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+            sizes = [
+                config.intermediate_size,
+                config.intermediate_size,
+            ]
             gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
-                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+                prefixes=prefixes,
                 weights=weights,
                 dim=0,
                 bias=bias,
@@ -258,11 +270,8 @@ class LlamaMLP(nn.Module):
         self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
             gate_up_proj,
             index,
-            ["gate_proj", "up_proj"],
-            sizes=[
-                config.intermediate_size,
-                config.intermediate_size,
-            ],
+            layer_names=prefixes,
+            sizes=sizes,
             process_group=weights.process_group,
         )
 
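A minimal stub, assuming the behaviour wired up across the hunks above, showing why leaving prefixes/sizes as None (the fused phi3 path) makes the multi-adapter wrapper a no-op; the class below is illustrative, not the repository's implementation:

import torch
from torch import nn


class MultiAdapterStub(nn.Module):
    def __init__(self, base_layer: nn.Module, layer_names, sizes):
        super().__init__()
        self.base_layer = base_layer
        self.layer_names = layer_names
        self.sizes = sizes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)
        if self.layer_names is None:
            return result  # no adapters to apply, mirror of the "noop" path
        # ... per-projection LoRA application would go here ...
        return result


fused = MultiAdapterStub(nn.Linear(16, 32), layer_names=None, sizes=None)
print(fused(torch.randn(2, 16)).shape)  # torch.Size([2, 32])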
@@ -273,9 +282,6 @@ class LlamaMLP(nn.Module):
             bias=bias,
         )
 
-        if config.model_type == "phi3":
-            self.down_proj = down_proj
-        else:
         self.down_proj = TensorParallelAdapterRowLinear.load(
             down_proj,
             index,
@@ -1,6 +1,7 @@
 import torch
 import os
 from loguru import logger
+from typing import Dict
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
@@ -30,9 +31,9 @@ def set_model_id(model_id: str):
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
 global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX = None
+ADAPTER_TO_INDEX: Dict[str, int] = None
 
 
-def set_adapter_to_index(adapter_to_index: dict):
+def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
     ADAPTER_TO_INDEX = adapter_to_index
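A standalone sketch of the module-level registry pattern typed above; Optional is added here because the initial value is None, and the example values are made up:

from typing import Dict, Optional

ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


set_adapter_to_index({"some-org/some-lora": 1})
print(ADAPTER_TO_INDEX)  # {'some-org/some-lora': 1}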