fix: refactor and move shard_lora_weights logic

This commit is contained in:
drbh 2024-06-24 22:41:28 +00:00
parent c927cffbf7
commit f94f2b3e6d
3 changed files with 51 additions and 49 deletions

View File

@ -30,6 +30,49 @@ if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
def shard_lora_weights(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
split_dim: int,
process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
]
# [r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
return weights_a, weights_b
@dataclass
class LoraConfig(AdapterConfig):
r: int
@ -206,7 +249,12 @@ class LoraWeights(AdapterWeights):
config.r = padded_rank
return LoraWeights(
*model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if model.is_row_parallel(layer_type) else 1,
process_group=model.process_group,
),
config,
)

View File

@ -254,14 +254,14 @@ class LlamaMLP(nn.Module):
bias=bias,
)
else:
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
prefixes = [f"gate_proj", f"up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=prefixes,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=bias,

View File

@ -21,32 +21,6 @@ from loguru import logger
BASE_MODEL_ADAPTER_ID = "__base_model__"
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
B = TypeVar("B", bound=Batch)
@ -273,26 +247,6 @@ class Model(ABC):
self.loaded_adapters.add(adapter_index)
def shard_lora_weights(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
layer_type: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
split_dim = 0 if self.is_row_parallel(layer_type) else 1
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=self.process_group)
for w in weights_a
]
# [r, hidden_size]
weights_b = [
shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
]
return weights_a, weights_b
def offload_adapter(
self,
adapter_parameters: AdapterParameters,