From f94f2b3e6d371ba57038e8427768124bf4ff14ae Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 24 Jun 2024 22:41:28 +0000
Subject: [PATCH] fix: refactor and move shard_lora_weights logic

---
 .../text_generation_server/adapters/lora.py   | 50 ++++++++++++++++++-
 .../custom_modeling/flash_llama_modeling.py   |  4 +-
 server/text_generation_server/models/model.py | 46 -----------------
 3 files changed, 51 insertions(+), 49 deletions(-)

diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py
index d176d150..87543be2 100644
--- a/server/text_generation_server/adapters/lora.py
+++ b/server/text_generation_server/adapters/lora.py
@@ -30,6 +30,49 @@ if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
 
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+    block_size = size // world_size
+    start = offset + rank * block_size
+    stop = offset + (rank + 1) * block_size
+    return start, stop
+
+
+def shard_on_dim(
+    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+    world_size = process_group.size()
+    rank = process_group.rank()
+
+    size = t.shape[dim]
+    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+    if dim == 0:
+        tensor = t[start:stop]
+    elif dim == 1:
+        tensor = t[:, start:stop]
+    else:
+        raise NotImplementedError("Let's make that generic when needed")
+
+    return tensor
+
+
+def shard_lora_weights(
+    weights_a: List[torch.Tensor],
+    weights_b: List[torch.Tensor],
+    split_dim: int,
+    process_group: ProcessGroup,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    # [hidden_size, r]
+    weights_a = [
+        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
+    ]
+
+    # [r, hidden_size]
+    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
+
+    return weights_a, weights_b
+
+
 @dataclass
 class LoraConfig(AdapterConfig):
     r: int
@@ -206,7 +249,12 @@ class LoraWeights(AdapterWeights):
             config.r = padded_rank
 
         return LoraWeights(
-            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
+            *shard_lora_weights(
+                weights_a=lora_a_list,
+                weights_b=lora_b_list,
+                split_dim=0 if model.is_row_parallel(layer_type) else 1,
+                process_group=model.process_group,
+            ),
             config,
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f787b846..c48ed268 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -254,14 +254,14 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
         else:
-            prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+            prefixes = [f"gate_proj", f"up_proj"]
             sizes = [
                 config.intermediate_size,
                 config.intermediate_size,
             ]
             gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
-                prefixes=prefixes,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 8da44273..c90fd38a 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -21,32 +21,6 @@ from loguru import logger
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
-def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
-    block_size = size // world_size
-    start = offset + rank * block_size
-    stop = offset + (rank + 1) * block_size
-    return start, stop
-
-
-def shard_on_dim(
-    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
-):
-    world_size = process_group.size()
-    rank = process_group.rank()
-
-    size = t.shape[dim]
-    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
-
-    if dim == 0:
-        tensor = t[start:stop]
-    elif dim == 1:
-        tensor = t[:, start:stop]
-    else:
-        raise NotImplementedError("Let's make that generic when needed")
-
-    return tensor
-
-
 B = TypeVar("B", bound=Batch)
 
 
@@ -273,26 +247,6 @@ class Model(ABC):
 
         self.loaded_adapters.add(adapter_index)
 
-    def shard_lora_weights(
-        self,
-        weights_a: List[torch.Tensor],
-        weights_b: List[torch.Tensor],
-        layer_type: str,
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-        # [hidden_size, r]
-        split_dim = 0 if self.is_row_parallel(layer_type) else 1
-        weights_a = [
-            shard_on_dim(w, dim=split_dim, process_group=self.process_group)
-            for w in weights_a
-        ]
-
-        # [r, hidden_size]
-        weights_b = [
-            shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
-        ]
-
-        return weights_a, weights_b
-
     def offload_adapter(
         self,
         adapter_parameters: AdapterParameters,
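For reference, a minimal standalone sketch of how the relocated sharding helpers behave once they live in adapters/lora.py. The three functions are lightly condensed copies of the hunk added above; the _StubProcessGroup class, the example tensor shapes, and the single-process driver are hypothetical stand-ins for model.process_group, used only so the snippet runs without initializing torch.distributed.

# Standalone sketch of the sharding helpers moved into adapters/lora.py.
# _StubProcessGroup is a hypothetical stand-in for the torch.distributed
# ProcessGroup held by the model; it only needs to expose size() and rank().
from typing import List, Tuple

import torch


class _StubProcessGroup:
    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self._world_size = world_size

    def size(self) -> int:
        return self._world_size

    def rank(self) -> int:
        return self._rank


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    # Each rank owns one contiguous block of size // world_size elements.
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop


def shard_on_dim(t: torch.Tensor, dim: int, process_group) -> torch.Tensor:
    # Slice out this rank's block along `dim` (rows for dim=0, columns for dim=1).
    world_size = process_group.size()
    rank = process_group.rank()
    start, stop = get_start_stop_idxs_for_rank(0, t.shape[dim], rank, world_size)
    if dim == 0:
        return t[start:stop]
    elif dim == 1:
        return t[:, start:stop]
    raise NotImplementedError("Let's make that generic when needed")


def shard_lora_weights(
    weights_a: List[torch.Tensor],
    weights_b: List[torch.Tensor],
    split_dim: int,
    process_group,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # lora_a is [hidden_size, r]: split on split_dim (0 for row-parallel layers, 1 otherwise).
    weights_a = [
        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
    ]
    # lora_b is [r, hidden_size]: always split along the output dimension.
    weights_b = [
        shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b
    ]
    return weights_a, weights_b


if __name__ == "__main__":
    hidden_size, r = 16, 4
    lora_a = [torch.randn(hidden_size, r)]
    lora_b = [torch.randn(r, hidden_size)]

    # Rank 1 of a 2-way tensor-parallel group keeps the second half of each shard dim.
    pg = _StubProcessGroup(rank=1, world_size=2)
    a_shard, b_shard = shard_lora_weights(lora_a, lora_b, split_dim=0, process_group=pg)
    print(a_shard[0].shape)  # torch.Size([8, 4])
    print(b_shard[0].shape)  # torch.Size([4, 8])

With split_dim=0 (a row-parallel layer) each rank keeps a contiguous block of lora_a's rows, while lora_b is always split along its output dimension, mirroring the arguments LoraWeights.load now passes to the module-level shard_lora_weights.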