mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Proper TP
This commit is contained in:
parent
7d31cb6e75
commit
12e310f2a9
@ -102,7 +102,7 @@ def load_attention(config, prefix, weights):
|
|||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
elif config.model_type == "phi3":
|
elif config.model_type == "phi3":
|
||||||
return TensorParallelColumnLinear.load(
|
return TensorParallelColumnLinear.load_qkv(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.qkv_proj",
|
prefix=f"{prefix}.qkv_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@ -265,7 +265,7 @@ class LlamaMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
if config.model_type == "phi3":
|
if config.model_type == "phi3":
|
||||||
self.gate_up_proj = TensorParallelColumnLinear.load(
|
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.gate_up_proj",
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
@ -696,6 +696,19 @@ class TensorParallelHead(SuperLayer):
|
|||||||
|
|
||||||
|
|
||||||
class TensorParallelColumnLinear(SuperLayer):
|
class TensorParallelColumnLinear(SuperLayer):
|
||||||
|
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
    """Build the layer from a checkpoint whose gate and up projections are packed.

    Counterpart of ``load_qkv`` for MLP weights stored as one fused
    ``gate_up_proj`` tensor: the packed weight is fetched through
    ``weights.get_weights_col_packed_gate_up`` and wrapped with ``get_linear``.

    Args:
        config: Model configuration; only ``config.quantize`` is read here.
        prefix: Checkpoint name of the fused projection, without the
            ``.weight`` suffix.
        weights: Checkpoint accessor exposing
            ``get_weights_col_packed_gate_up``.
        bias: Must be falsy; a packed bias is not supported.

    Returns:
        An instance of ``cls`` wrapping the resulting linear layer.

    Raises:
        NotImplementedError: If ``bias`` is truthy.
    """
    # Reject the unsupported configuration before reading anything from the
    # checkpoint, so the caller fails fast instead of loading a tensor that
    # is then thrown away.
    if bias:
        raise NotImplementedError("packed_gate_up only implemented without bias")

    weight = weights.get_weights_col_packed_gate_up(
        prefix, quantize=config.quantize
    )
    # No bias tensor exists on this path, hence the explicit None.
    linear = get_linear(weight, None, config.quantize)
    return cls(linear)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
||||||
"""Specific method when the QKV was joined after the fact"""
|
"""Specific method when the QKV was joined after the fact"""
|
||||||
|
@ -141,6 +141,12 @@ class Weights:
|
|||||||
return weight
|
return weight
|
||||||
|
|
||||||
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
|
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
|
||||||
|
return self.get_weights_col_packed(prefix, quantize, 3)
|
||||||
|
|
||||||
|
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
    """Return this rank's shard of a packed gate/up projection weight.

    Thin wrapper around ``get_weights_col_packed`` that fixes the number of
    packed blocks to 2, in the same way ``get_weights_col_packed_qkv``
    fixes it to 3.

    Args:
        prefix: Checkpoint name of the packed tensor, without the
            ``.weight`` suffix.
        quantize: Quantization mode, forwarded unchanged.
    """
    return self.get_weights_col_packed(prefix, quantize, blocks=2)
|
||||||
|
|
||||||
|
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
|
||||||
"""
|
"""
|
||||||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||||
already alternating Q,K,V within the main tensor
|
already alternating Q,K,V within the main tensor
|
||||||
@ -181,8 +187,8 @@ class Weights:
|
|||||||
else:
|
else:
|
||||||
slice_ = self._get_slice(f"{prefix}.weight")
|
slice_ = self._get_slice(f"{prefix}.weight")
|
||||||
total_size = slice_.get_shape()[0]
|
total_size = slice_.get_shape()[0]
|
||||||
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
|
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
|
||||||
single_size = total_size // 3
|
single_size = total_size // blocks
|
||||||
world_size = self.process_group.size()
|
world_size = self.process_group.size()
|
||||||
rank = self.process_group.rank()
|
rank = self.process_group.rank()
|
||||||
|
|
||||||
@ -192,10 +198,11 @@ class Weights:
|
|||||||
block_size = single_size // world_size
|
block_size = single_size // world_size
|
||||||
start = rank * block_size
|
start = rank * block_size
|
||||||
stop = (rank + 1) * block_size
|
stop = (rank + 1) * block_size
|
||||||
q = slice_[start:stop]
|
tensors = []
|
||||||
k = slice_[start + single_size : stop + single_size]
|
for i in range(blocks):
|
||||||
v = slice_[start + 2 * single_size : stop + 2 * single_size]
|
tensor = slice_[start + i * single_size : stop + i * single_size]
|
||||||
weight = torch.cat([q, k, v], dim=0)
|
tensors.append(tensor)
|
||||||
|
weight = torch.cat(tensors, dim=0)
|
||||||
weight = weight.to(device=self.device)
|
weight = weight.to(device=self.device)
|
||||||
weight = weight.to(dtype=self.dtype)
|
weight = weight.to(dtype=self.dtype)
|
||||||
return weight
|
return weight
|
||||||
|
Loading…
Reference in New Issue
Block a user