mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Proper TP
This commit is contained in:
parent
7d31cb6e75
commit
12e310f2a9
@ -102,7 +102,7 @@ def load_attention(config, prefix, weights):
|
||||
bias=False,
|
||||
)
|
||||
elif config.model_type == "phi3":
|
||||
return TensorParallelColumnLinear.load(
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
weights=weights,
|
||||
@ -265,7 +265,7 @@ class LlamaMLP(nn.Module):
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
if config.model_type == "phi3":
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load(
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
weights=weights,
|
||||
|
@ -696,6 +696,19 @@ class TensorParallelHead(SuperLayer):
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
    """Load a fused gate/up projection stored as one pre-packed tensor.

    The tensor found under ``prefix`` is fetched through
    ``weights.get_weights_col_packed_gate_up`` and wrapped in the linear
    layer built by ``get_linear``.

    Args:
        config: Model configuration; only ``config.quantize`` is read here.
        prefix: Checkpoint prefix of the fused gate/up projection.
        weights: Loader exposing ``get_weights_col_packed_gate_up``.
        bias: Must be falsy; a bias is not supported for this packed layout.

    Returns:
        An instance of ``cls`` wrapping the loaded linear layer.

    Raises:
        NotImplementedError: If ``bias`` is truthy.
    """
    # Reject the unsupported configuration before touching the checkpoint so
    # that no tensor is loaded only to be thrown away.
    if bias:
        raise NotImplementedError("packed_gate_up only implemented without bias")

    weight = weights.get_weights_col_packed_gate_up(
        prefix, quantize=config.quantize
    )
    # No bias tensor exists for the packed gate/up layout.
    linear = get_linear(weight, None, config.quantize)
    return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
|
@ -141,6 +141,12 @@ class Weights:
|
||||
return weight
|
||||
|
||||
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
    """Load a pre-packed Q/K/V tensor.

    Delegates to ``get_weights_col_packed`` with a block count of three.
    """
    qkv_blocks = 3
    return self.get_weights_col_packed(prefix, quantize, qkv_blocks)
|
||||
|
||||
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
    """Load a pre-packed gate/up tensor.

    Delegates to ``get_weights_col_packed`` with a block count of two.
    """
    gate_up_blocks = 2
    return self.get_weights_col_packed(prefix, quantize, gate_up_blocks)
|
||||
|
||||
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
|
||||
"""
|
||||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||
already alternating Q,K,V within the main tensor
|
||||
@ -181,8 +187,8 @@ class Weights:
|
||||
else:
|
||||
slice_ = self._get_slice(f"{prefix}.weight")
|
||||
total_size = slice_.get_shape()[0]
|
||||
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
|
||||
single_size = total_size // 3
|
||||
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
|
||||
single_size = total_size // blocks
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
@ -192,10 +198,11 @@ class Weights:
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q = slice_[start:stop]
|
||||
k = slice_[start + single_size : stop + single_size]
|
||||
v = slice_[start + 2 * single_size : stop + 2 * single_size]
|
||||
weight = torch.cat([q, k, v], dim=0)
|
||||
tensors = []
|
||||
for i in range(blocks):
|
||||
tensor = slice_[start + i * single_size : stop + i * single_size]
|
||||
tensors.append(tensor)
|
||||
weight = torch.cat(tensors, dim=0)
|
||||
weight = weight.to(device=self.device)
|
||||
weight = weight.to(dtype=self.dtype)
|
||||
return weight
|
||||
|
Loading…
Reference in New Issue
Block a user