Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Fix sharding (?)
This commit is contained in:
parent 4ce8b6f0ee · commit 1510461d93
@@ -91,11 +91,28 @@ def load_row(config, transpose: bool, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None

     return TensorParallelRowLinear(
         get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )

+
+def load_col(config, transpose: bool, prefix: str, weights, bias: bool):
+    if transpose:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+    else:
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )
+
+    if bias:
+        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashGPT2Attention(torch.nn.Module):
     def __init__(
         self,
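
For intuition on why `load_col` and `load_row` shard along different dimensions: a column-parallel linear splits the weight along its output features, so each rank computes a distinct slice of the output and the slices are concatenated; a row-parallel linear splits along the input features, so each rank produces a partial sum that must be reduced across the process group. A minimal single-process sketch of both schemes, with plain `torch.chunk` standing in for the `weights.get_sharded` helper (world size is simulated, not distributed):

    import torch

    # Simulate 2-way tensor parallelism in one process (illustrative only).
    world_size = 2
    x = torch.randn(4, 8)   # [batch, in_features]
    w = torch.randn(6, 8)   # [out_features, in_features], torch.nn.Linear convention
    y_ref = x @ w.T         # unsharded reference output

    # Column parallel: shard the output dimension (dim 0 of [out, in]).
    # Each rank owns a slice of the output; slices are concatenated.
    y_col = torch.cat([x @ shard.T for shard in w.chunk(world_size, dim=0)], dim=-1)

    # Row parallel: shard the input dimension (dim 1 of [out, in]).
    # Each rank sees a slice of x and produces a partial sum; the sum below plays
    # the role of the all-reduce that TensorParallelRowLinear performs.
    y_row = sum(
        xs @ shard.T
        for xs, shard in zip(x.chunk(world_size, dim=-1), w.chunk(world_size, dim=1))
    )

    torch.testing.assert_close(y_col, y_ref)
    torch.testing.assert_close(y_row, y_ref)

The `transpose` branch reflects GPT-2's Conv1D convention of storing weights as [in_features, out_features]: `get_sharded(..., dim=1).T` therefore shards what becomes, after the transpose, the output dimension.
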
@@ -106,8 +123,8 @@ class FlashGPT2Attention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads

+        self.head_size = self.hidden_size // self.num_heads
         self.softmax_scale = self.head_size**-0.5

         if self.num_heads % weights.process_group.size() != 0:
@@ -116,7 +133,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.num_key_value_heads = self.num_heads // weights.process_group.size()

         self.query_key_value = _load_qkv(
             config,
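
The removed assignment divided by the world size twice: by the time it ran, `self.num_heads` had already been sharded on the line above, so `num_key_value_heads` came out a factor of `num_shards` too small, or zero. Since GPT-2 uses plain multi-head attention, each shard keeps exactly as many KV heads as query heads. A toy illustration with hypothetical numbers:

    # Hypothetical config: 12 attention heads sharded over 4 GPUs.
    num_attention_heads = 12
    num_shards = 4

    num_heads = num_attention_heads // num_shards  # 3 query heads per shard
    buggy_kv_heads = num_heads // num_shards       # 3 // 4 == 0  (double division)
    fixed_kv_heads = num_heads                     # MHA: one KV head per query head

    print(num_heads, buggy_kv_heads, fixed_kv_heads)  # 3 0 3
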
@@ -133,10 +149,10 @@ class FlashGPT2Attention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads

         self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )

     def forward(
         self,
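
`kv_head_mapping` tells the attention kernel which KV head each query head reads from. The removed `repeat_interleave` form is the grouped-query-attention pattern, where `num_groups` query heads share one KV head; for a pure-MHA model like GPT-2 it degenerates to the identity mapping that the added `torch.arange` spells out directly. Side by side, with hypothetical head counts:

    import torch

    # GQA-style mapping: 8 query heads sharing 4 KV heads in groups of 2.
    num_kv_heads, num_groups = 4, 2
    gqa_map = torch.arange(0, num_kv_heads, dtype=torch.int32).repeat_interleave(num_groups)
    print(gqa_map.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]

    # MHA mapping (GPT-2): every query head owns its KV head.
    num_heads = 8
    mha_map = torch.arange(0, num_heads, dtype=torch.int32)
    print(mha_map.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
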
@@ -204,7 +220,7 @@ class GPT2MLP(nn.Module):
             )
         )

-        self.c_fc = load_row(
+        self.c_fc = load_col(
             config, prefix=f"{prefix}.c_fc", weights=weights, transpose=True, bias=True
         )
         self.c_proj = load_row(
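
Loading `c_fc` column-parallel rather than row-parallel matters because the MLP computes `c_proj(act(c_fc(x)))`: splitting `c_fc` along its output features lets each shard apply the elementwise activation to its own slice with no communication, and the row-parallel `c_proj` then reduces once at the end. A single-process sanity check of that factorization, a sketch only, with `gelu` standing in for the model's configured activation:

    import torch
    import torch.nn.functional as F

    world_size = 2
    x = torch.randn(4, 8)
    w_fc = torch.randn(32, 8)    # up-projection   [out, in]
    w_proj = torch.randn(8, 32)  # down-projection [out, in]

    dense = F.gelu(x @ w_fc.T) @ w_proj.T

    # Column-parallel c_fc + row-parallel c_proj: the activation runs per shard,
    # and only the final sum (one all-reduce) crosses shard boundaries.
    sharded = sum(
        F.gelu(x @ fc.T) @ pr.T
        for fc, pr in zip(w_fc.chunk(world_size, dim=0), w_proj.chunk(world_size, dim=1))
    )
    torch.testing.assert_close(sharded, dense)
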
@@ -309,7 +325,6 @@ class FlashGPT2Model(torch.nn.Module):

         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

     def forward(
         self,
@@ -87,7 +87,7 @@ class FlashGPT2(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
+            num_kv_heads=model.model.num_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
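
`num_kv_heads` is what `FlashCausalLM` uses to size the per-shard KV cache, so with the `num_key_value_heads` attribute removed from the model it falls back to the already-sharded `num_heads`. Roughly what this parameter controls, with an illustrative paged block layout that is not TGI's exact allocator:

    import torch

    num_blocks, block_size = 128, 16  # hypothetical paged-KV geometry
    num_kv_heads, head_size = 3, 64   # e.g. 12 heads / 4 shards = 3 per shard

    # One such pair is allocated per transformer layer.
    k_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)
    v_cache = torch.empty_like(k_cache)
    print(k_cache.shape)  # torch.Size([128, 16, 3, 64])
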