fix: refactors and helpful comments

drbh 2024-06-24 22:01:41 +00:00
parent a07b612989
commit 3c9b28eaec
5 changed files with 141 additions and 93 deletions

View File

@@ -4,6 +4,7 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
unit-tests:
pytest -s -vv -m "not private" tests

View File

@@ -85,6 +85,11 @@ def serve(
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
if len(lora_adapter_ids) > 0:
logger.warning(
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
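For reference, a minimal standalone sketch of how the comma-separated lora_adapter_ids string above is parsed (illustrative only, not the actual CLI wiring):

def parse_adapter_ids(lora_adapter_ids: str) -> list[str]:
    # mirrors the split/strip logic in serve(): an empty string yields an empty list
    return [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []

assert parse_adapter_ids("") == []
assert parse_adapter_ids("org/adapter-a, org/adapter-b") == ["org/adapter-a", "org/adapter-b"]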

View File

@@ -1,12 +1,13 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple, List
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
@@ -17,16 +18,15 @@ from text_generation_server.utils.sgmv import (
orient_for_rank,
)
LORA = "lora"
MEDUSA = "medusa"
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
def __init__(self, base_layer, layer_id, process_group):
def __init__(
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
):
super().__init__()
self.base_layer = base_layer
self.layer_id = layer_id
@@ -41,14 +41,24 @@ class LoraLinear(nn.Module):
start_idx: int,
end_idx: int,
) -> torch.Tensor:
if adapter_data is None:
return result
data = adapter_data.data.get(layer_type)
data: Optional["BatchLoraWeights"] = (
data.get(LORA) if data is not None else None
data.get("lora") if data is not None else None
)
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
# In tensor-parallel configurations each GPU holds only its slice of the layer output.
# 'result' is this rank's (possibly fused) output, and its width depends on the layer
# type (e.g. attention vs. feed-forward). start_idx and end_idx mark the columns of
# 'result' that belong to the current layer_type. If that segment is narrower than
# 'result' itself, we accumulate the LoRA contribution into a zero tensor of the
# segment's size and add it back into result[:, start_idx:end_idx] afterwards.
#
# Scenarios where the sizes differ:
# 1. The adapter's dimensions don't divide evenly across GPUs.
# 2. The last segment is smaller than the others.
# 3. Different projection layers (q, k, v) have different sizes, e.g. grouped-query attention.
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
@@ -58,55 +68,57 @@ class LoraLinear(nn.Module):
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
if lora_a_ptr is not None and lora_b_ptr is not None:
v = lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
v = lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
if lora_a_ptr is not None and lora_b_ptr is not None:
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
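To make the segment handling above concrete, here is a small self-contained sketch; the delta tensor stands in for the SGMV/BGMV LoRA contribution, and all names and shapes are illustrative only:

import torch

def add_segment(result: torch.Tensor, delta: torch.Tensor, start_idx: int, end_idx: int) -> torch.Tensor:
    # 'result' is this rank's fused output; columns [start_idx, end_idx) belong to the
    # current layer type (e.g. q_proj inside a fused qkv projection).
    if end_idx - start_idx != result.shape[1]:
        # the segment is narrower than the local output: accumulate into a buffer
        # of the segment's width, then add it back into the matching columns
        proj = torch.zeros_like(result[:, start_idx:end_idx])
        proj += delta
        result[:, start_idx:end_idx] += proj
    else:
        # the segment covers the whole local output: accumulate in place
        result += delta
    return result

out = torch.zeros(2, 6)                      # batch of 2, fused width 6
add_segment(out, torch.ones(2, 4), 0, 4)     # q slice: columns 0..3
add_segment(out, torch.ones(2, 1), 4, 5)     # k slice: column 4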
@@ -149,13 +161,27 @@ class LoraLinear(nn.Module):
class TensorParallelMultiAdapterLinear(LoraLinear):
def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
def __init__(
self,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
super().__init__(base_layer, layer_id, process_group)
self.layer_names = layer_names
self.sizes = sizes
@classmethod
def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
def load(
cls,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
return TensorParallelMultiAdapterLinear(
base_layer, layer_id, layer_names, sizes, process_group
)
@@ -165,6 +191,10 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
) -> torch.Tensor:
result = self.base_layer(input)
# noop if no layer names are provided (e.g. for models without adapters)
if self.layer_names is None:
return result
# handle models like Bloom that have inputs of shape
# (batch_size, sequence_length, hidden_size)
# we need to reshape them to (batch_size * sequence_length, hidden_size)
@@ -178,7 +208,12 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
offset = 0
for i, layer_name in enumerate(self.layer_names):
start_idx = offset // self.process_group.size()
# 'sizes' holds the output dimension of each projection packed into this layer
# (q_proj, k_proj, v_proj), which is what lets us slice the result tensor correctly
# per sub-layer. This matters for grouped-query attention, where k_proj and v_proj
# are smaller than q_proj, and it lets each projection's LoRA weights be applied to
# its own slice of the output.
if self.sizes is not None:
offset += self.sizes[i]
end_idx = offset // self.process_group.size()
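As a worked example of the offset arithmetic above, with hypothetical values (32 query heads, 8 KV heads, head size 128, 4-way tensor parallelism):

head_size, num_heads, num_kv_heads, world_size = 128, 32, 8, 4
sizes = [head_size * num_heads, head_size * num_kv_heads, head_size * num_kv_heads]

offset = 0
for layer_name, size in zip(["q_proj", "k_proj", "v_proj"], sizes):
    start_idx = offset // world_size  # start of this rank's slice of the fused output
    offset += size
    end_idx = offset // world_size    # end of this rank's slice
    print(layer_name, start_idx, end_idx)
# q_proj 0 1024, k_proj 1024 1280, v_proj 1280 1536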

View File

@@ -56,28 +56,37 @@ if SYSTEM == "rocm":
def load_attention(config, prefix, weights, layer_id):
# attention_bias is only defined in granite configs; default to False otherwise.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
sizes = None
prefixes = None
# phi3 and baichuan pack q, k and v into a single fused projection, so load them via load_qkv
if config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
prefix = f"{prefix}.qkv_proj"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
prefix=prefix,
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
prefix = f"{prefix}.W_pack"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
prefix=prefix,
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
else:
# otherwise, load the default attention based on the number of heads
prefixes = ["q_proj", "k_proj", "v_proj"]
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -86,18 +95,13 @@ def load_attention(config, prefix, weights, layer_id):
bias=bias,
)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
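As a quick sanity check of the head_size / sizes computation above, for a hypothetical grouped-query-attention config (hidden size 4096, 32 attention heads, 8 key/value heads):

hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 8
head_size = hidden_size // num_attention_heads   # 128
sizes = [
    head_size * num_attention_heads,             # q_proj: 4096
    head_size * num_key_value_heads,             # k_proj: 1024
    head_size * num_key_value_heads,             # v_proj: 1024
]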
class FlashLlamaAttention(torch.nn.Module):
@@ -237,34 +241,39 @@ class LlamaMLP(nn.Module):
),
)
)
prefixes = None
sizes = None
# Fuse gate and up proj
bias = getattr(config, "mlp_bias", False)
if config.model_type == "phi3":
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
bias=bias,
)
else:
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
prefixes=prefixes,
weights=weights,
dim=0,
bias=bias,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
@@ -273,15 +282,12 @@ class LlamaMLP(nn.Module):
bias=bias,
)
if config.model_type == "phi3":
self.down_proj = down_proj
else:
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()

View File

@@ -1,6 +1,7 @@
import torch
import os
from loguru import logger
from typing import Dict, Optional
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
@@ -30,9 +31,9 @@ def set_model_id(model_id: str):
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX = None
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
def set_adapter_to_index(adapter_to_index: dict):
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX = adapter_to_index
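For context, a minimal standalone sketch of how a typed module-level registry like this is used; the lookup helper is illustrative, not necessarily this module's real API:

from typing import Dict, Optional

ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

def set_adapter_to_index(adapter_to_index: Dict[str, int]) -> None:
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index

def lookup_adapter_index(adapter_id: str) -> Optional[int]:
    # returns None until set_adapter_to_index() has been called during model load
    return (ADAPTER_TO_INDEX or {}).get(adapter_id)

set_adapter_to_index({"my-org/my-lora": 1})   # hypothetical adapter id
assert lookup_adapter_index("my-org/my-lora") == 1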