Proper TP

Nicolas Patry 2024-04-23 13:01:26 +00:00
parent 7d31cb6e75
commit 12e310f2a9
3 changed files with 28 additions and 8 deletions

View File

@@ -102,7 +102,7 @@ def load_attention(config, prefix, weights):
             bias=False,
         )
     elif config.model_type == "phi3":
-        return TensorParallelColumnLinear.load(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
@@ -265,7 +265,7 @@ class LlamaMLP(nn.Module):
         )
         # Fuse gate and up proj
         if config.model_type == "phi3":
-            self.gate_up_proj = TensorParallelColumnLinear.load(
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
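Why the dedicated loaders: phi3 checkpoints ship qkv_proj and gate_up_proj as single tensors that are plain concatenations along the output dimension. The generic TensorParallelColumnLinear.load shards a weight into one contiguous chunk of rows per rank, which would hand rank 0 nothing but Q (or gate) rows and the last rank nothing but V (or up) rows. A minimal sketch of the failure and the blockwise fix, in plain PyTorch with no real process group (the helper names naive_col_shard and blockwise_col_shard are illustrative, not TGI API):

import torch

def naive_col_shard(packed, rank, world_size):
    # Generic column sharding: one contiguous chunk of rows per rank.
    chunk = packed.shape[0] // world_size
    return packed[rank * chunk : (rank + 1) * chunk]

def blockwise_col_shard(packed, rank, world_size, blocks):
    # What load_qkv / load_gate_up do: shard each packed block separately,
    # then concatenate, so every rank gets a slice of *each* projection.
    single = packed.shape[0] // blocks
    block = single // world_size
    start, stop = rank * block, (rank + 1) * block
    return torch.cat(
        [packed[start + i * single : stop + i * single] for i in range(blocks)], dim=0
    )

gate = torch.zeros(4, 2)         # gate rows tagged 0
up = torch.ones(4, 2)            # up rows tagged 1
gate_up = torch.cat([gate, up])  # fused the way phi3 stores gate_up_proj

# Naive sharding on 2 ranks: rank 0 gets only gate rows, rank 1 only up rows.
assert naive_col_shard(gate_up, rank=0, world_size=2).eq(0).all()
assert naive_col_shard(gate_up, rank=1, world_size=2).eq(1).all()

# Blockwise sharding: each rank gets half of gate *and* half of up.
shard0 = blockwise_col_shard(gate_up, rank=0, world_size=2, blocks=2)
assert shard0[:2].eq(0).all() and shard0[2:].eq(1).all()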

View File

@@ -696,6 +696,19 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the gate and up projections were joined after the fact"""
+        weight = weights.get_weights_col_packed_gate_up(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_gate_up only implemented without bias")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""

View File

@@ -141,6 +141,12 @@ class Weights:
         return weight
 
     def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 3)
+
+    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 2)
+
+    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
@@ -181,8 +187,8 @@ class Weights:
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
-            single_size = total_size // 3
+            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+            single_size = total_size // blocks
             world_size = self.process_group.size()
             rank = self.process_group.rank()
@@ -192,10 +198,11 @@ class Weights:
             block_size = single_size // world_size
             start = rank * block_size
             stop = (rank + 1) * block_size
-            q = slice_[start:stop]
-            k = slice_[start + single_size : stop + single_size]
-            v = slice_[start + 2 * single_size : stop + 2 * single_size]
-            weight = torch.cat([q, k, v], dim=0)
+            tensors = []
+            for i in range(blocks):
+                tensor = slice_[start + i * single_size : stop + i * single_size]
+                tensors.append(tensor)
+            weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
             return weight
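To make the index arithmetic of the new loop concrete, here is a standalone re-implementation outside the Weights class (illustrative helper name, plus an added divisibility assert that the real method handles elsewhere), checked against hand-computed row indices for the QKV case (blocks=3) on two ranks:

import torch

def col_packed_slice(weight, blocks, rank, world_size):
    # Mirrors get_weights_col_packed: take this rank's slice of each of the
    # `blocks` equal sections packed along dim 0, then concatenate them.
    total_size = weight.shape[0]
    assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
    single_size = total_size // blocks
    assert single_size % world_size == 0  # assumption added for this sketch
    block_size = single_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    tensors = []
    for i in range(blocks):
        tensors.append(weight[start + i * single_size : stop + i * single_size])
    return torch.cat(tensors, dim=0)

qkv = torch.arange(12).reshape(12, 1)  # rows 0-3 = Q, 4-7 = K, 8-11 = V
rank0 = col_packed_slice(qkv, blocks=3, rank=0, world_size=2)
assert rank0.flatten().tolist() == [0, 1, 4, 5, 8, 9]
rank1 = col_packed_slice(qkv, blocks=3, rank=1, world_size=2)
assert rank1.flatten().tolist() == [2, 3, 6, 7, 10, 11]

With blocks=2 the same loop covers the phi3 gate_up case, which is the whole point of generalizing the hardcoded q/k/v slices.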