mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: refactor and move shard_lora_weights logic
This commit is contained in:
parent
c927cffbf7
commit
f94f2b3e6d
@ -30,6 +30,49 @@ if TYPE_CHECKING:
|
||||
from text_generation_server.models.model import Model
|
||||
|
||||
|
||||
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||
block_size = size // world_size
|
||||
start = offset + rank * block_size
|
||||
stop = offset + (rank + 1) * block_size
|
||||
return start, stop
|
||||
|
||||
|
||||
def shard_on_dim(
|
||||
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||
):
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
size = t.shape[dim]
|
||||
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||
|
||||
if dim == 0:
|
||||
tensor = t[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = t[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def shard_lora_weights(
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
split_dim: int,
|
||||
process_group: ProcessGroup,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
# [hidden_size, r]
|
||||
weights_a = [
|
||||
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
|
||||
]
|
||||
|
||||
# [r, hidden_size]
|
||||
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
|
||||
|
||||
return weights_a, weights_b
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraConfig(AdapterConfig):
|
||||
r: int
|
||||
@ -206,7 +249,12 @@ class LoraWeights(AdapterWeights):
|
||||
config.r = padded_rank
|
||||
|
||||
return LoraWeights(
|
||||
*model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
|
||||
*shard_lora_weights(
|
||||
weights_a=lora_a_list,
|
||||
weights_b=lora_b_list,
|
||||
split_dim=0 if model.is_row_parallel(layer_type) else 1,
|
||||
process_group=model.process_group,
|
||||
),
|
||||
config,
|
||||
)
|
||||
|
||||
|
@ -254,14 +254,14 @@ class LlamaMLP(nn.Module):
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||
prefixes = [f"gate_proj", f"up_proj"]
|
||||
sizes = [
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
]
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=prefixes,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=bias,
|
||||
|
@ -21,32 +21,6 @@ from loguru import logger
|
||||
BASE_MODEL_ADAPTER_ID = "__base_model__"
|
||||
|
||||
|
||||
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||
block_size = size // world_size
|
||||
start = offset + rank * block_size
|
||||
stop = offset + (rank + 1) * block_size
|
||||
return start, stop
|
||||
|
||||
|
||||
def shard_on_dim(
|
||||
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||
):
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
size = t.shape[dim]
|
||||
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||
|
||||
if dim == 0:
|
||||
tensor = t[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = t[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
|
||||
@ -273,26 +247,6 @@ class Model(ABC):
|
||||
|
||||
self.loaded_adapters.add(adapter_index)
|
||||
|
||||
def shard_lora_weights(
|
||||
self,
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
layer_type: str,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
# [hidden_size, r]
|
||||
split_dim = 0 if self.is_row_parallel(layer_type) else 1
|
||||
weights_a = [
|
||||
shard_on_dim(w, dim=split_dim, process_group=self.process_group)
|
||||
for w in weights_a
|
||||
]
|
||||
|
||||
# [r, hidden_size]
|
||||
weights_b = [
|
||||
shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
|
||||
]
|
||||
|
||||
return weights_a, weights_b
|
||||
|
||||
def offload_adapter(
|
||||
self,
|
||||
adapter_parameters: AdapterParameters,
|
||||
|
Loading…
Reference in New Issue
Block a user