Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix: refactor and move shard_lora_weights logic

Commit f94f2b3e6d (parent c927cffbf7)

The commit moves the get_start_stop_idxs_for_rank and shard_on_dim helpers out of the base model module and turns the Model.shard_lora_weights method into a module-level shard_lora_weights function alongside the LoRA adapter code, updating the LoraWeights call site to pass split_dim and process_group explicitly. A prefix handling tweak in LlamaMLP rides along.
@@ -30,6 +30,49 @@ if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
 
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+    block_size = size // world_size
+    start = offset + rank * block_size
+    stop = offset + (rank + 1) * block_size
+    return start, stop
+
+
+def shard_on_dim(
+    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+    world_size = process_group.size()
+    rank = process_group.rank()
+
+    size = t.shape[dim]
+    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+    if dim == 0:
+        tensor = t[start:stop]
+    elif dim == 1:
+        tensor = t[:, start:stop]
+    else:
+        raise NotImplementedError("Let's make that generic when needed")
+
+    return tensor
+
+
+def shard_lora_weights(
+    weights_a: List[torch.Tensor],
+    weights_b: List[torch.Tensor],
+    split_dim: int,
+    process_group: ProcessGroup,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    # [hidden_size, r]
+    weights_a = [
+        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
+    ]
+
+    # [r, hidden_size]
+    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
+
+    return weights_a, weights_b
+
+
 @dataclass
 class LoraConfig(AdapterConfig):
     r: int
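The helpers added above are plain functions of the rank and world size, so the partitioning can be sanity-checked without starting a distributed job. Below is a minimal standalone sketch (tensor sizes are made up for illustration) that simulates two ranks in one process and reproduces the dim == 0 branch of shard_on_dim on a [hidden_size, r] LoRA A matrix; the index helper is copied from the hunk above.

import torch


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    # Copied from the hunk above: contiguous, equal-sized blocks per rank.
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop


lora_a = torch.randn(12, 8)  # hidden_size=12, r=8 (made-up sizes)
world_size = 2

for rank in range(world_size):
    start, stop = get_start_stop_idxs_for_rank(0, lora_a.shape[0], rank, world_size)
    shard = lora_a[start:stop]  # what the dim == 0 branch of shard_on_dim returns on this rank
    print(f"rank {rank}: rows [{start}:{stop}) -> {tuple(shard.shape)}")
# rank 0: rows [0:6) -> (6, 8)
# rank 1: rows [6:12) -> (6, 8)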
@@ -206,7 +249,12 @@ class LoraWeights(AdapterWeights):
         config.r = padded_rank
 
         return LoraWeights(
-            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
+            *shard_lora_weights(
+                weights_a=lora_a_list,
+                weights_b=lora_b_list,
+                split_dim=0 if model.is_row_parallel(layer_type) else 1,
+                process_group=model.process_group,
+            ),
             config,
         )
 
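For reference, here is a hedged, self-contained sketch of calling the now module-level shard_lora_weights the way the updated call site above does. The helper bodies are condensed copies of the hunk that adds them, _FakeGroup is a made-up stand-in for a torch.distributed ProcessGroup (only .rank() and .size() are used), and all tensor sizes are arbitrary.

import torch


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    block_size = size // world_size
    return offset + rank * block_size, offset + (rank + 1) * block_size


def shard_on_dim(t, dim, process_group):
    # Condensed from the hunk above; only dims 0 and 1 are supported there.
    start, stop = get_start_stop_idxs_for_rank(
        0, t.shape[dim], process_group.rank(), process_group.size()
    )
    return t[start:stop] if dim == 0 else t[:, start:stop]


def shard_lora_weights(weights_a, weights_b, split_dim, process_group):
    # A matrices ([hidden_size, r]) are split on split_dim, B matrices ([r, hidden_size]) on dim 1.
    weights_a = [shard_on_dim(w, split_dim, process_group) for w in weights_a]
    weights_b = [shard_on_dim(w, 1, process_group) for w in weights_b]
    return weights_a, weights_b


class _FakeGroup:
    """Illustration-only stand-in for a torch.distributed ProcessGroup."""

    def __init__(self, rank: int, world_size: int):
        self._rank, self._world_size = rank, world_size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._world_size


# Two layers of made-up LoRA weights: A is [hidden_size, r], B is [r, hidden_size].
lora_a_list = [torch.randn(16, 8) for _ in range(2)]
lora_b_list = [torch.randn(8, 16) for _ in range(2)]

a_shards, b_shards = shard_lora_weights(
    weights_a=lora_a_list,
    weights_b=lora_b_list,
    split_dim=0,  # mirrors: 0 if model.is_row_parallel(layer_type) else 1
    process_group=_FakeGroup(rank=0, world_size=2),
)
print([tuple(t.shape) for t in a_shards])  # [(8, 8), (8, 8)] -> A split on dim 0
print([tuple(t.shape) for t in b_shards])  # [(8, 8), (8, 8)] -> B split on dim 1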
@@ -254,14 +254,14 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
         else:
-            prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+            prefixes = [f"gate_proj", f"up_proj"]
             sizes = [
                 config.intermediate_size,
                 config.intermediate_size,
             ]
             gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
-                prefixes=prefixes,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
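A short illustration of the prefix change above, using a hypothetical layer prefix (the exact string is an assumption, not taken from the diff): the prefixes list now holds bare projection names, while TensorParallelColumnLinear.load_multi still receives the fully qualified weight names.

# Hypothetical prefix value, for illustration only.
prefix = "model.layers.0.mlp"

# New code keeps bare names in `prefixes` ...
prefixes = ["gate_proj", "up_proj"]
# ... while load_multi is still given the fully qualified names, as before.
load_multi_prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]

print(prefixes)             # ['gate_proj', 'up_proj']
print(load_multi_prefixes)  # ['model.layers.0.mlp.gate_proj', 'model.layers.0.mlp.up_proj']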
@@ -21,32 +21,6 @@ from loguru import logger
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
-def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
-    block_size = size // world_size
-    start = offset + rank * block_size
-    stop = offset + (rank + 1) * block_size
-    return start, stop
-
-
-def shard_on_dim(
-    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
-):
-    world_size = process_group.size()
-    rank = process_group.rank()
-
-    size = t.shape[dim]
-    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
-
-    if dim == 0:
-        tensor = t[start:stop]
-    elif dim == 1:
-        tensor = t[:, start:stop]
-    else:
-        raise NotImplementedError("Let's make that generic when needed")
-
-    return tensor
-
-
 B = TypeVar("B", bound=Batch)
 
 
@@ -273,26 +247,6 @@ class Model(ABC):
 
         self.loaded_adapters.add(adapter_index)
 
-    def shard_lora_weights(
-        self,
-        weights_a: List[torch.Tensor],
-        weights_b: List[torch.Tensor],
-        layer_type: str,
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-        # [hidden_size, r]
-        split_dim = 0 if self.is_row_parallel(layer_type) else 1
-        weights_a = [
-            shard_on_dim(w, dim=split_dim, process_group=self.process_group)
-            for w in weights_a
-        ]
-
-        # [r, hidden_size]
-        weights_b = [
-            shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
-        ]
-
-        return weights_a, weights_b
-
     def offload_adapter(
         self,
         adapter_parameters: AdapterParameters,