fix: refactor and move shard_lora_weights logic

Author: drbh
Date: 2024-06-24 22:41:28 +00:00
Parent: c927cffbf7
Commit: f94f2b3e6d
3 changed files with 51 additions and 49 deletions


@@ -30,6 +30,49 @@ if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
 
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+    block_size = size // world_size
+    start = offset + rank * block_size
+    stop = offset + (rank + 1) * block_size
+    return start, stop
+
+
+def shard_on_dim(
+    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+    world_size = process_group.size()
+    rank = process_group.rank()
+
+    size = t.shape[dim]
+    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+    if dim == 0:
+        tensor = t[start:stop]
+    elif dim == 1:
+        tensor = t[:, start:stop]
+    else:
+        raise NotImplementedError("Let's make that generic when needed")
+
+    return tensor
+
+
+def shard_lora_weights(
+    weights_a: List[torch.Tensor],
+    weights_b: List[torch.Tensor],
+    split_dim: int,
+    process_group: ProcessGroup,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    # [hidden_size, r]
+    weights_a = [
+        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
+    ]
+
+    # [r, hidden_size]
+    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
+
+    return weights_a, weights_b
+
+
 @dataclass
 class LoraConfig(AdapterConfig):
     r: int
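
For reference, a minimal, self-contained sketch of how these helpers partition a tensor: each rank keeps one contiguous block along the chosen dimension. The FakeGroup class below is a hypothetical stand-in for a real torch.distributed.ProcessGroup so the logic can run in a single process; it is not part of this diff.

# Illustration only: simulate two ranks without initializing torch.distributed.
import torch


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    # Contiguous, equal-sized block for each rank.
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop


class FakeGroup:
    # Hypothetical stand-in exposing the two ProcessGroup methods used here.
    def __init__(self, rank, world_size):
        self._rank, self._world_size = rank, world_size

    def rank(self):
        return self._rank

    def size(self):
        return self._world_size


def shard_on_dim(t, dim, process_group):
    rank, world_size = process_group.rank(), process_group.size()
    start, stop = get_start_stop_idxs_for_rank(0, t.shape[dim], rank, world_size)
    return t[start:stop] if dim == 0 else t[:, start:stop]


lora_a = torch.arange(16.0).reshape(8, 2)  # e.g. [hidden_size=8, r=2]
for rank in range(2):
    shard = shard_on_dim(lora_a, dim=0, process_group=FakeGroup(rank, world_size=2))
    print(rank, shard.shape)  # rank 0 keeps rows 0-3, rank 1 keeps rows 4-7: both (4, 2)

Because size // world_size truncates, the ranks only receive equal blocks when the sharded dimension is a multiple of the world size, which is why the LoRA rank is padded before sharding (see the config.r = padded_rank context in the next hunk).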
@@ -206,7 +249,12 @@ class LoraWeights(AdapterWeights):
             config.r = padded_rank
 
         return LoraWeights(
-            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
+            *shard_lora_weights(
+                weights_a=lora_a_list,
+                weights_b=lora_b_list,
+                split_dim=0 if model.is_row_parallel(layer_type) else 1,
+                process_group=model.process_group,
+            ),
             config,
         )
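
To make the new split_dim argument concrete, here is a rough shape-only sketch (not part of the diff) of what this call site selects: for a row-parallel target layer the LoRA A matrices are split along dim 0, otherwise along dim 1, while the B matrices are always split along dim 1. The sizes (hidden_size=8, r=4, world_size=2) are made up, and model.is_row_parallel is replaced by a plain boolean.

# Illustration only: expected shard shapes for the two split_dim cases.
import torch

world_size = 2
hidden_size, r = 8, 4
lora_a = torch.randn(hidden_size, r)  # [hidden_size, r]
lora_b = torch.randn(r, hidden_size)  # [r, hidden_size]

for is_row_parallel in (True, False):
    split_dim = 0 if is_row_parallel else 1
    a_shape = list(lora_a.shape)
    a_shape[split_dim] //= world_size           # A is split on the selected dim
    b_shape = [r, hidden_size // world_size]    # B is always split on dim 1
    print(f"row_parallel={is_row_parallel}: A shard {a_shape}, B shard {b_shape}")
# row_parallel=True:  A shard [4, 4], B shard [4, 4]
# row_parallel=False: A shard [8, 2], B shard [4, 4]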


@@ -254,14 +254,14 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
         else:
-            prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+            prefixes = [f"gate_proj", f"up_proj"]
             sizes = [
                 config.intermediate_size,
                 config.intermediate_size,
             ]
             gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
-                prefixes=prefixes,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
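
As an aside, load_multi here loads gate_proj and up_proj into a single fused projection. The snippet below is a generic illustration of that fused gate/up pattern (concatenate the two weights along the output dimension, split the result after one matmul) with made-up sizes; it is a sketch of the general technique, not TGI's TensorParallelColumnLinear implementation.

# Illustration only: the fused gate/up pattern, not TGI's implementation.
import torch

hidden_size, intermediate_size = 16, 32
gate_w = torch.randn(intermediate_size, hidden_size)  # gate_proj weight
up_w = torch.randn(intermediate_size, hidden_size)    # up_proj weight

# Stack along the output dimension (dim 0) so one matmul yields both halves.
gate_up_w = torch.cat([gate_w, up_w], dim=0)          # [2 * intermediate_size, hidden_size]

x = torch.randn(1, hidden_size)
fused = x @ gate_up_w.T                               # [1, 2 * intermediate_size]
gate, up = fused.split(intermediate_size, dim=1)

# Identical to applying the two projections separately.
assert torch.allclose(gate, x @ gate_w.T, atol=1e-5)
assert torch.allclose(up, x @ up_w.T, atol=1e-5)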


@@ -21,32 +21,6 @@ from loguru import logger
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
-def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
-    block_size = size // world_size
-    start = offset + rank * block_size
-    stop = offset + (rank + 1) * block_size
-    return start, stop
-
-
-def shard_on_dim(
-    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
-):
-    world_size = process_group.size()
-    rank = process_group.rank()
-
-    size = t.shape[dim]
-    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
-
-    if dim == 0:
-        tensor = t[start:stop]
-    elif dim == 1:
-        tensor = t[:, start:stop]
-    else:
-        raise NotImplementedError("Let's make that generic when needed")
-
-    return tensor
-
-
 B = TypeVar("B", bound=Batch)
@@ -273,26 +247,6 @@ class Model(ABC):
             self.loaded_adapters.add(adapter_index)
 
-    def shard_lora_weights(
-        self,
-        weights_a: List[torch.Tensor],
-        weights_b: List[torch.Tensor],
-        layer_type: str,
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-        # [hidden_size, r]
-        split_dim = 0 if self.is_row_parallel(layer_type) else 1
-        weights_a = [
-            shard_on_dim(w, dim=split_dim, process_group=self.process_group)
-            for w in weights_a
-        ]
-
-        # [r, hidden_size]
-        weights_b = [
-            shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
-        ]
-
-        return weights_a, weights_b
-
     def offload_adapter(
         self,
         adapter_parameters: AdapterParameters,