Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00

Fix sharding (?)
commit 1510461d93 (parent 4ce8b6f0ee)
```diff
@@ -91,11 +91,28 @@ def load_row(config, transpose: bool, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
 
     return TensorParallelRowLinear(
         get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )
 
+
+def load_col(config, transpose: bool, prefix: str, weights, bias: bool):
+    if transpose:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+    else:
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )
+
+    if bias:
+        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashGPT2Attention(torch.nn.Module):
     def __init__(
         self,
```
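The new `load_col` is the column-parallel counterpart of `load_row`: each shard keeps a contiguous slice of the output features (and therefore its slice of the bias), whereas `load_row` shards the input features and keeps the bias whole so it is added only once after the partial results are reduced. The `transpose` branch handles GPT-2's Conv1D-style checkpoints, which store the weight transposed, so the sharding dimension flips to `dim=1` before the `.T`. Below is a minimal sketch of the two slicing schemes, using plain tensors and an illustrative `rank`/`world_size` pair rather than the repo's `Weights` helper; the function names are made up for the example.

```python
import torch

def shard_col(weight: torch.Tensor, bias: torch.Tensor, rank: int, world_size: int):
    # Column-parallel: split the output features; each shard owns its bias slice.
    block = weight.shape[0] // world_size          # weight is [out_features, in_features]
    return weight[rank * block:(rank + 1) * block], bias[rank * block:(rank + 1) * block]

def shard_row(weight: torch.Tensor, rank: int, world_size: int):
    # Row-parallel: split the input features; the full bias is added once,
    # after the shards' partial outputs are all-reduced.
    block = weight.shape[1] // world_size
    return weight[:, rank * block:(rank + 1) * block]

weight, bias = torch.randn(8, 4), torch.randn(8)
w_col, b_col = shard_col(weight, bias, rank=0, world_size=2)   # [4, 4] and [4]
w_row = shard_row(weight, rank=0, world_size=2)                # [8, 2]
```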
```diff
@@ -106,8 +123,8 @@ class FlashGPT2Attention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+
+        self.head_size = self.hidden_size // self.num_heads
         self.softmax_scale = self.head_size**-0.5
 
         if self.num_heads % weights.process_group.size() != 0:
```
```diff
@@ -116,7 +133,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.num_key_value_heads = self.num_heads // weights.process_group.size()
 
         self.query_key_value = _load_qkv(
             config,
```
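The removed line is the heart of the sharding bug: `self.num_heads` has already been divided by the world size on the previous line, so deriving `num_key_value_heads` from it divided by the shard count a second time. GPT-2 uses plain multi-head attention (no grouped-query attention), so the per-shard KV head count is simply the per-shard head count, and the attribute is dropped entirely. A quick arithmetic check with illustrative numbers (12 heads, 2 shards):

```python
num_attention_heads = 12   # e.g. GPT-2 small, for illustration
num_shards = 2

num_heads = num_attention_heads // num_shards        # 6 heads per shard (correct)
num_key_value_heads = num_heads // num_shards        # 3 (old code: divided twice, wrong)

assert num_heads == 6 and num_key_value_heads == 3   # the mismatch that broke sharding
```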
```diff
@@ -133,10 +149,10 @@ class FlashGPT2Attention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )
 
     def forward(
         self,
```
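With the bogus KV head count gone, the grouped-query-style head mapping is simplified as well: `kv_head_mapping` maps each query head to the KV head it reads from in the KV cache, and for pure multi-head attention that mapping is just the identity, so `arange(self.num_heads)` replaces the `arange(...).repeat_interleave(num_groups)` form used for grouped-query models. A small illustration with made-up head counts:

```python
import torch

# Grouped-query attention: 8 query heads share 2 KV heads (4 per group).
num_heads, num_key_value_heads = 8, 2
num_groups = num_heads // num_key_value_heads
gqa = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
# tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

# Plain multi-head attention (GPT-2): each query head has its own KV head.
mha = torch.arange(0, num_heads, dtype=torch.int32)
# tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
```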
```diff
@@ -204,7 +220,7 @@ class GPT2MLP(nn.Module):
             )
         )
 
-        self.c_fc = load_row(
+        self.c_fc = load_col(
             config, prefix=f"{prefix}.c_fc", weights=weights, transpose=True, bias=True
         )
         self.c_proj = load_row(
```
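Loading `c_fc` column-parallel and `c_proj` row-parallel gives the MLP the standard Megatron-style split: each shard computes its slice of the intermediate activations, applies the activation locally, and the row-parallel `c_proj` produces partial outputs that are summed with a single all-reduce. A single-process sketch of why this composition matches the unsharded result (ReLU stands in for the model's GELU, and the loop simulates the shards):

```python
import torch

torch.manual_seed(0)
world_size, hidden, intermediate = 2, 4, 8
x = torch.randn(1, hidden)
c_fc_w = torch.randn(hidden, intermediate)
c_proj_w = torch.randn(intermediate, hidden)

reference = torch.relu(x @ c_fc_w) @ c_proj_w        # unsharded MLP

partials = []
for rank in range(world_size):
    cols = slice(rank * intermediate // world_size, (rank + 1) * intermediate // world_size)
    h = torch.relu(x @ c_fc_w[:, cols])              # column-parallel c_fc
    partials.append(h @ c_proj_w[cols, :])           # row-parallel c_proj (partial output)
sharded = sum(partials)                              # the all-reduce in the real model

assert torch.allclose(reference, sharded, atol=1e-5)
```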
```diff
@@ -309,7 +325,6 @@ class FlashGPT2Model(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
         self,
```
```diff
@@ -87,7 +87,7 @@ class FlashGPT2(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
+            num_kv_heads=model.model.num_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
```
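The last hunk is the matching change in the `FlashCausalLM` wrapper: `num_kv_heads` is what sizes each shard's KV cache, so it must be the per-shard head count; with the old, doubly divided `num_key_value_heads` the cache would be sized for fewer heads than the attention expects. A back-of-the-envelope sketch of the per-shard sizing; the block shape here is an assumption for illustration, not the wrapper's actual allocation code:

```python
num_attention_heads, num_shards, head_size, block_size = 12, 2, 64, 16

num_heads_per_shard = num_attention_heads // num_shards   # 6
old_kv_heads = num_heads_per_shard // num_shards           # 3 (value before this commit)

# Hypothetical per-shard K (or V) block shape: [num_kv_heads, head_size, block_size]
correct_block = (num_heads_per_shard, head_size, block_size)   # (6, 64, 16)
too_small = (old_kv_heads, head_size, block_size)              # (3, 64, 16)
print(correct_block, too_small)
```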