Fix sharding (?)

Daniël de Kok 2024-05-13 09:51:04 +00:00
parent 4ce8b6f0ee
commit 1510461d93
2 changed files with 23 additions and 8 deletions


@@ -91,11 +91,28 @@ def load_row(config, transpose: bool, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
     return TensorParallelRowLinear(
         get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )
 
 
+def load_col(config, transpose: bool, prefix: str, weights, bias: bool):
+    if transpose:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+    else:
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )
+
+    if bias:
+        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashGPT2Attention(torch.nn.Module):
     def __init__(
         self,
@@ -106,8 +123,8 @@ class FlashGPT2Attention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
 
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = self.hidden_size // self.num_heads
         self.softmax_scale = self.head_size**-0.5
 
         if self.num_heads % weights.process_group.size() != 0:
@ -116,7 +133,6 @@ class FlashGPT2Attention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}" f"and `num_shards`: {weights.process_group.size()}"
) )
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = self.num_heads // weights.process_group.size()
self.query_key_value = _load_qkv( self.query_key_value = _load_qkv(
config, config,
@@ -133,10 +149,10 @@ class FlashGPT2Attention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
+
         self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )
 
     def forward(
         self,
@ -204,7 +220,7 @@ class GPT2MLP(nn.Module):
) )
) )
self.c_fc = load_row( self.c_fc = load_col(
config, prefix=f"{prefix}.c_fc", weights=weights, transpose=True, bias=True config, prefix=f"{prefix}.c_fc", weights=weights, transpose=True, bias=True
) )
self.c_proj = load_row( self.c_proj = load_row(
@@ -309,7 +325,6 @@ class FlashGPT2Model(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
         self,
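
GPT-2 uses plain multi-head attention (one KV head per query head), so after sharding each rank owns num_heads // num_shards heads and the KV-head mapping is simply the identity, which is what the torch.arange change above produces. A minimal standalone sketch of that arithmetic, using hypothetical values (12 attention heads, 4 shards) rather than anything from the patch:

import torch

num_attention_heads = 12  # hypothetical config value
num_shards = 4            # hypothetical tensor-parallel world size

# Divisibility requirement checked in FlashGPT2Attention.__init__.
assert num_attention_heads % num_shards == 0
num_local_heads = num_attention_heads // num_shards

# Multi-head attention: query head i on this shard reads KV cache slot i,
# so the mapping is just the identity over the local heads.
kv_head_mapping = torch.arange(0, num_local_heads, dtype=torch.int32)
print(kv_head_mapping)  # tensor([0, 1, 2], dtype=torch.int32)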


@@ -87,7 +87,7 @@ class FlashGPT2(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
+            num_kv_heads=model.model.num_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
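
The c_fc switch from load_row to load_col matches the usual Megatron-style MLP layout: the first projection is split along its output features (column-parallel), the second along its input features (row-parallel), and the per-shard partial outputs are summed (the all-reduce performed by TensorParallelRowLinear). A small self-contained sketch, with hypothetical sizes and ReLU standing in for GELU, showing that the sharded computation reproduces the unsharded forward pass:

import torch

torch.manual_seed(0)
hidden, intermediate, num_shards = 8, 32, 4  # hypothetical sizes

x = torch.randn(1, hidden)
w_fc = torch.randn(hidden, intermediate)    # first projection (c_fc)
w_proj = torch.randn(intermediate, hidden)  # second projection (c_proj)

# Unsharded reference forward pass.
reference = torch.relu(x @ w_fc) @ w_proj

# Column-parallel c_fc / row-parallel c_proj: each shard holds one slice of the
# intermediate dimension; summing the shard outputs stands in for the all-reduce.
cols = intermediate // num_shards
partial = sum(
    torch.relu(x @ w_fc[:, r * cols : (r + 1) * cols])
    @ w_proj[r * cols : (r + 1) * cols, :]
    for r in range(num_shards)
)
print(torch.allclose(reference, partial, atol=1e-4))  # True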